diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 0906e9a6ba..0e8b9bf1b2 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -534,18 +534,36 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): logger.info('Search memory response received.') memory_events: list[MemoryEntry] = [] - async for retrieved_memory in retrieved_memories_iterator: - # TODO: add more complex error handling - logger.debug('Retrieved memory: %s', retrieved_memory) - memory_events.append( - MemoryEntry( - author='user', - content=types.Content( - parts=[types.Part(text=retrieved_memory.memory.fact)], - role='user', - ), - timestamp=retrieved_memory.memory.update_time.isoformat(), + try: + async for retrieved_memory in retrieved_memories_iterator: + try: + memory = retrieved_memory.memory + if memory is None: + logger.warning('Skipping memory entry with missing memory object.') + continue + fact = memory.fact + if not fact: + logger.warning('Skipping memory entry with empty or missing fact.') + continue + update_time = memory.update_time + memory_events.append( + MemoryEntry( + author='user', + content=types.Content( + parts=[types.Part(text=fact)], + role='user', + ), + timestamp=update_time.isoformat() if update_time else None, + ) + ) + except AttributeError: + logger.warning( + 'Skipping malformed memory entry: %s', retrieved_memory ) + except Exception: + logger.exception( + 'Error while iterating memory results. Returning %d partial results.', + len(memory_events), ) return SearchMemoryResponse(memories=memory_events) diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index b6eead6465..e41e797fa2 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -1009,7 +1009,6 @@ async def test_search_memory_empty_results(mock_vertexai_client): assert len(result.memories) == 0 -@pytest.mark.asyncio async def test_search_memory_uses_async_client_path(): sync_client = mock.MagicMock() sync_client.agent_engines.memories.retrieve.side_effect = AssertionError( @@ -1039,3 +1038,110 @@ async def test_search_memory_uses_async_client_path(): similarity_search_params={'search_query': 'query'}, ) sync_client.agent_engines.memories.retrieve.assert_not_called() + + +@pytest.mark.asyncio +async def test_search_memory_skips_entry_with_none_memory(mock_vertexai_client): + bad_entry = mock.MagicMock() + bad_entry.memory = None + + good_entry = mock.MagicMock() + good_entry.memory.fact = 'good fact' + good_entry.memory.update_time = datetime.datetime(2024, 1, 1) + + mock_vertexai_client.agent_engines.memories.retrieve.return_value = ( + _AsyncListIterator([bad_entry, good_entry]) + ) + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'good fact' + + +@pytest.mark.asyncio +async def test_search_memory_skips_entry_with_empty_fact(mock_vertexai_client): + for empty_fact in [None, '']: + bad_entry = mock.MagicMock() + bad_entry.memory.fact = empty_fact + bad_entry.memory.update_time = datetime.datetime(2024, 1, 1) + + mock_vertexai_client.agent_engines.memories.retrieve.return_value = ( + _AsyncListIterator([bad_entry]) + ) + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + assert len(result.memories) == 0 + + +@pytest.mark.asyncio +async def test_search_memory_handles_missing_update_time(mock_vertexai_client): + entry = mock.MagicMock() + entry.memory.fact = 'some fact' + entry.memory.update_time = None + + mock_vertexai_client.agent_engines.memories.retrieve.return_value = ( + _AsyncListIterator([entry]) + ) + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'some fact' + assert result.memories[0].timestamp is None + + +@pytest.mark.asyncio +async def test_search_memory_skips_malformed_entry(mock_vertexai_client): + malformed = mock.MagicMock(spec=[]) # no attributes → AttributeError + + good_entry = mock.MagicMock() + good_entry.memory.fact = 'good fact' + good_entry.memory.update_time = datetime.datetime(2024, 1, 1) + + mock_vertexai_client.agent_engines.memories.retrieve.return_value = ( + _AsyncListIterator([malformed, good_entry]) + ) + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'good fact' + + +@pytest.mark.asyncio +async def test_search_memory_returns_partial_results_on_iterator_error( + mock_vertexai_client, +): + good_entry = mock.MagicMock() + good_entry.memory.fact = 'good fact' + good_entry.memory.update_time = datetime.datetime(2024, 1, 1) + + async def failing_async_iterator(): + yield good_entry + raise RuntimeError('API stream error') + + mock_vertexai_client.agent_engines.memories.retrieve.return_value = ( + failing_async_iterator() + ) + memory_service = mock_vertex_ai_memory_bank_service() + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query' + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'good fact'