diff --git a/src/agentic_layer/memory_manager.py b/src/agentic_layer/memory_manager.py index df42b05a..eea9807b 100644 --- a/src/agentic_layer/memory_manager.py +++ b/src/agentic_layer/memory_manager.py @@ -94,6 +94,13 @@ MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository, } +# MemoryType -> Milvus Repository mapping +MILVUS_REPO_MAP = { + MemoryType.FORESIGHT: ForesightMilvusRepository, + MemoryType.EVENT_LOG: EventLogMilvusRepository, + MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository, +} + @dataclass class EventLogCandidate: @@ -299,7 +306,7 @@ async def retrieve_mem_keyword( """Keyword-based memory retrieval""" start_time = time.perf_counter() memory_type = ( - retrieve_mem_request.memory_types[0].value + ','.join(mt.value for mt in retrieve_mem_request.memory_types) if retrieve_mem_request.memory_types else 'unknown' ) @@ -340,7 +347,7 @@ async def get_keyword_search_results( """Keyword search with stage-level metrics""" stage_start = time.perf_counter() memory_type = ( - retrieve_mem_request.memory_types[0].value + ','.join(mt.value for mt in retrieve_mem_request.memory_types) if retrieve_mem_request.memory_types else 'unknown' ) @@ -375,32 +382,35 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") - return [] + all_results = [] + for mem_type in memory_types: + repo_class = ES_REPO_MAP.get(mem_type) + if not repo_class: + logger.info( + f"Skipping unsupported memory_type for keyword search: {mem_type}" + ) + continue - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_id=group_id, - size=top_k, - from_=0, - date_range=date_range, - ) + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_id=group_id, + size=top_k, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + if results: + for r in results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 0.0) # Unified score field + all_results.extend(results) # Record stage metrics record_retrieve_stage( @@ -410,7 +420,7 @@ async def get_keyword_search_results( duration_seconds=time.perf_counter() - stage_start, ) - return results or [] + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -434,7 +444,7 @@ async def retrieve_mem_vector( """Vector-based memory retrieval""" start_time = time.perf_counter() memory_type = ( - retrieve_mem_request.memory_types[0].value + ','.join(mt.value for mt in retrieve_mem_request.memory_types) if retrieve_mem_request.memory_types else 'unknown' ) @@ -474,7 +484,7 @@ async def get_vector_search_results( ) -> List[Dict[str, Any]]: """Vector search with stage-level metrics (embedding + milvus_search)""" memory_type = ( - retrieve_mem_request.memory_types[0].value + ','.join(mt.value for mt in retrieve_mem_request.memory_types) if retrieve_mem_request.memory_types else 'unknown' ) @@ -497,7 +507,7 @@ async def get_vector_search_results( top_k = retrieve_mem_request.top_k start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] + memory_types = retrieve_mem_request.memory_types logger.debug( f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" @@ -506,7 +516,7 @@ async def get_vector_search_results( # Get vectorization service vectorize_service = get_vectorize_service() - # Convert query text to vector (embedding stage) + # Convert query text to vector (embedding stage) - only once for all types logger.debug(f"Starting to vectorize query text: {query}") embedding_start = time.perf_counter() query_vector = await vectorize_service.get_embedding(query) @@ -521,21 +531,9 @@ async def get_vector_search_results( f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.EVENT_LOG: - milvus_repo = get_bean_by_type(EventLogMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: {mem_type}") - # Handle time range filter conditions start_time_dt = None end_time_dt = None - current_time_dt = None if start_time is not None: start_time_dt = ( @@ -553,62 +551,75 @@ async def get_vector_search_results( else: end_time_dt = end_time - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - if retrieve_mem_request.current_time: - current_time_dt = from_iso_format(retrieve_mem_request.current_time) - - # Call Milvus vector search (pass different parameters based on memory type) - milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - current_time=current_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) - else: - # Episodic memory and event log: use timestamp filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, + # Search all supported memory types + all_results = [] + for mem_type in memory_types: + repo_class = MILVUS_REPO_MAP.get(mem_type) + if not repo_class: + logger.info( + f"Skipping unsupported memory_type for vector search: {mem_type}" + ) + continue + + milvus_repo = get_bean_by_type(repo_class) + + # Handle foresight-specific time range + type_start_time = start_time_dt + type_end_time = end_time_dt + current_time_dt = None + if mem_type == MemoryType.FORESIGHT: + if retrieve_mem_request.start_time: + type_start_time = from_iso_format( + retrieve_mem_request.start_time + ) + if retrieve_mem_request.end_time: + type_end_time = from_iso_format( + retrieve_mem_request.end_time + ) + if retrieve_mem_request.current_time: + current_time_dt = from_iso_format( + retrieve_mem_request.current_time + ) + + # Call Milvus vector search + milvus_start = time.perf_counter() + if mem_type == MemoryType.FORESIGHT: + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=type_start_time, + end_time=type_end_time, + current_time=current_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + else: + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + record_retrieve_stage( + retrieve_method=retrieve_method, + stage='milvus_search', + memory_type=mem_type.value, + duration_seconds=time.perf_counter() - milvus_start, ) - record_retrieve_stage( - retrieve_method=retrieve_method, - stage='milvus_search', - memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, - ) - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename + for r in search_results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + all_results.extend(search_results) - return search_results + return all_results except Exception as e: - record_retrieve_stage( - retrieve_method=retrieve_method, - stage=RetrieveMethod.VECTOR.value, - memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, - ) record_retrieve_error( retrieve_method=retrieve_method, stage=RetrieveMethod.VECTOR.value, @@ -625,7 +636,7 @@ async def retrieve_mem_hybrid( """Hybrid memory retrieval: keyword + vector + rerank""" start_time = time.perf_counter() memory_type = ( - retrieve_mem_request.memory_types[0].value + ','.join(mt.value for mt in retrieve_mem_request.memory_types) if retrieve_mem_request.memory_types else 'unknown' ) @@ -700,7 +711,9 @@ async def _search_hybrid( ) -> List[Dict]: """Core hybrid search: keyword + vector + rerank, returns flat list""" memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' + ','.join(mt.value for mt in request.memory_types) + if request.memory_types + else 'unknown' ) # Run keyword and vector search concurrently kw_results, vec_results = await asyncio.gather( @@ -723,7 +736,9 @@ async def _search_rrf( ) -> List[Dict]: """Core RRF search: keyword + vector + RRF fusion, returns flat list""" memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' + ','.join(mt.value for mt in request.memory_types) + if request.memory_types + else 'unknown' ) # Run keyword and vector search concurrently @@ -766,7 +781,11 @@ async def _to_response( """Convert flat hits list to grouped RetrieveMemResponse""" user_id = req.user_id if req else "" source_type = req.retrieve_method.value - memory_type = req.memory_types[0].value + memory_type = ( + ','.join(mt.value for mt in req.memory_types) + if req.memory_types + else 'unknown' + ) if not hits: return RetrieveMemResponse( @@ -809,7 +828,7 @@ async def retrieve_mem_rrf( """RRF-based memory retrieval: keyword + vector + RRF fusion""" start_time = time.perf_counter() memory_type = ( - retrieve_mem_request.memory_types[0].value + ','.join(mt.value for mt in retrieve_mem_request.memory_types) if retrieve_mem_request.memory_types else 'unknown' ) @@ -855,7 +874,11 @@ async def retrieve_mem_agentic( req = retrieve_mem_request # alias top_k = req.top_k config = AgenticConfig() - memory_type = req.memory_types[0].value if req.memory_types else 'unknown' + memory_type = ( + ','.join(mt.value for mt in req.memory_types) + if req.memory_types + else 'unknown' + ) try: llm_provider = LLMProvider(