-
Notifications
You must be signed in to change notification settings - Fork 246
fix(agentic): search API now queries all requested memory_types inste… #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -94,6 +94,13 @@ | |
| MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository, | ||
| } | ||
|
|
||
| # MemoryType -> Milvus Repository mapping | ||
| MILVUS_REPO_MAP = { | ||
| MemoryType.FORESIGHT: ForesightMilvusRepository, | ||
| MemoryType.EVENT_LOG: EventLogMilvusRepository, | ||
| MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository, | ||
| } | ||
|
|
||
|
|
||
| @dataclass | ||
| class EventLogCandidate: | ||
|
|
@@ -299,7 +306,7 @@ async def retrieve_mem_keyword( | |
| """Keyword-based memory retrieval""" | ||
| start_time = time.perf_counter() | ||
| memory_type = ( | ||
| retrieve_mem_request.memory_types[0].value | ||
| ','.join(mt.value for mt in retrieve_mem_request.memory_types) | ||
| if retrieve_mem_request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
@@ -340,7 +347,7 @@ async def get_keyword_search_results( | |
| """Keyword search with stage-level metrics""" | ||
| stage_start = time.perf_counter() | ||
| memory_type = ( | ||
| retrieve_mem_request.memory_types[0].value | ||
| ','.join(mt.value for mt in retrieve_mem_request.memory_types) | ||
| if retrieve_mem_request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
@@ -375,32 +382,35 @@ async def get_keyword_search_results( | |
| if end_time is not None: | ||
| date_range["lte"] = end_time | ||
|
|
||
| mem_type = memory_types[0] | ||
|
|
||
| repo_class = ES_REPO_MAP.get(mem_type) | ||
| if not repo_class: | ||
| logger.warning(f"Unsupported memory_type: {mem_type}") | ||
| return [] | ||
| all_results = [] | ||
| for mem_type in memory_types: | ||
| repo_class = ES_REPO_MAP.get(mem_type) | ||
| if not repo_class: | ||
| logger.info( | ||
| f"Skipping unsupported memory_type for keyword search: {mem_type}" | ||
| ) | ||
| continue | ||
|
|
||
| es_repo = get_bean_by_type(repo_class) | ||
| logger.debug(f"Using {repo_class.__name__} for {mem_type}") | ||
| es_repo = get_bean_by_type(repo_class) | ||
| logger.debug(f"Using {repo_class.__name__} for {mem_type}") | ||
|
|
||
| results = await es_repo.multi_search( | ||
| query=query_words, | ||
| user_id=user_id, | ||
| group_id=group_id, | ||
| size=top_k, | ||
| from_=0, | ||
| date_range=date_range, | ||
| ) | ||
| results = await es_repo.multi_search( | ||
| query=query_words, | ||
| user_id=user_id, | ||
| group_id=group_id, | ||
| size=top_k, | ||
| from_=0, | ||
| date_range=date_range, | ||
| ) | ||
|
|
||
| # Mark memory_type, search_source, and unified score | ||
| if results: | ||
| for r in results: | ||
| r['memory_type'] = mem_type.value | ||
| r['_search_source'] = RetrieveMethod.KEYWORD.value | ||
| r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' | ||
| r['score'] = r.get('_score', 0.0) # Unified score field | ||
| # Mark memory_type, search_source, and unified score | ||
| if results: | ||
| for r in results: | ||
| r['memory_type'] = mem_type.value | ||
| r['_search_source'] = RetrieveMethod.KEYWORD.value | ||
| r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' | ||
| r['score'] = r.get('_score', 0.0) # Unified score field | ||
| all_results.extend(results) | ||
|
|
||
| # Record stage metrics | ||
| record_retrieve_stage( | ||
|
|
@@ -410,7 +420,7 @@ async def get_keyword_search_results( | |
| duration_seconds=time.perf_counter() - stage_start, | ||
| ) | ||
|
|
||
| return results or [] | ||
| return all_results | ||
| except Exception as e: | ||
| record_retrieve_stage( | ||
| retrieve_method=retrieve_method, | ||
|
|
@@ -434,7 +444,7 @@ async def retrieve_mem_vector( | |
| """Vector-based memory retrieval""" | ||
| start_time = time.perf_counter() | ||
| memory_type = ( | ||
| retrieve_mem_request.memory_types[0].value | ||
| ','.join(mt.value for mt in retrieve_mem_request.memory_types) | ||
| if retrieve_mem_request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
@@ -474,7 +484,7 @@ async def get_vector_search_results( | |
| ) -> List[Dict[str, Any]]: | ||
| """Vector search with stage-level metrics (embedding + milvus_search)""" | ||
| memory_type = ( | ||
| retrieve_mem_request.memory_types[0].value | ||
| ','.join(mt.value for mt in retrieve_mem_request.memory_types) | ||
| if retrieve_mem_request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
@@ -497,7 +507,7 @@ async def get_vector_search_results( | |
| top_k = retrieve_mem_request.top_k | ||
| start_time = retrieve_mem_request.start_time | ||
| end_time = retrieve_mem_request.end_time | ||
| mem_type = retrieve_mem_request.memory_types[0] | ||
| memory_types = retrieve_mem_request.memory_types | ||
|
|
||
| logger.debug( | ||
| f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" | ||
|
|
@@ -506,7 +516,7 @@ async def get_vector_search_results( | |
| # Get vectorization service | ||
| vectorize_service = get_vectorize_service() | ||
|
|
||
| # Convert query text to vector (embedding stage) | ||
| # Convert query text to vector (embedding stage) - only once for all types | ||
| logger.debug(f"Starting to vectorize query text: {query}") | ||
| embedding_start = time.perf_counter() | ||
| query_vector = await vectorize_service.get_embedding(query) | ||
|
|
@@ -521,21 +531,9 @@ async def get_vector_search_results( | |
| f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" | ||
| ) | ||
|
|
||
| # Select Milvus repository based on memory type | ||
| match mem_type: | ||
| case MemoryType.FORESIGHT: | ||
| milvus_repo = get_bean_by_type(ForesightMilvusRepository) | ||
| case MemoryType.EVENT_LOG: | ||
| milvus_repo = get_bean_by_type(EventLogMilvusRepository) | ||
| case MemoryType.EPISODIC_MEMORY: | ||
| milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) | ||
| case _: | ||
| raise ValueError(f"Unsupported memory type: {mem_type}") | ||
|
|
||
| # Handle time range filter conditions | ||
| start_time_dt = None | ||
| end_time_dt = None | ||
| current_time_dt = None | ||
|
|
||
| if start_time is not None: | ||
| start_time_dt = ( | ||
|
|
@@ -553,62 +551,75 @@ async def get_vector_search_results( | |
| else: | ||
| end_time_dt = end_time | ||
|
|
||
| # Handle foresight time range (only valid for foresight) | ||
| if mem_type == MemoryType.FORESIGHT: | ||
| if retrieve_mem_request.start_time: | ||
| start_time_dt = from_iso_format(retrieve_mem_request.start_time) | ||
| if retrieve_mem_request.end_time: | ||
| end_time_dt = from_iso_format(retrieve_mem_request.end_time) | ||
| if retrieve_mem_request.current_time: | ||
| current_time_dt = from_iso_format(retrieve_mem_request.current_time) | ||
|
|
||
| # Call Milvus vector search (pass different parameters based on memory type) | ||
| milvus_start = time.perf_counter() | ||
| if mem_type == MemoryType.FORESIGHT: | ||
| # Foresight: supports time range and validity filtering, supports radius parameter | ||
| search_results = await milvus_repo.vector_search( | ||
| query_vector=query_vector_list, | ||
| user_id=user_id, | ||
| group_id=group_id, | ||
| start_time=start_time_dt, | ||
| end_time=end_time_dt, | ||
| current_time=current_time_dt, | ||
| limit=top_k, | ||
| score_threshold=0.0, | ||
| radius=retrieve_mem_request.radius, | ||
| ) | ||
| else: | ||
| # Episodic memory and event log: use timestamp filtering, supports radius parameter | ||
| search_results = await milvus_repo.vector_search( | ||
| query_vector=query_vector_list, | ||
| user_id=user_id, | ||
| group_id=group_id, | ||
| start_time=start_time_dt, | ||
| end_time=end_time_dt, | ||
| limit=top_k, | ||
| score_threshold=0.0, | ||
| radius=retrieve_mem_request.radius, | ||
| # Search all supported memory types | ||
| all_results = [] | ||
| for mem_type in memory_types: | ||
| repo_class = MILVUS_REPO_MAP.get(mem_type) | ||
| if not repo_class: | ||
| logger.info( | ||
| f"Skipping unsupported memory_type for vector search: {mem_type}" | ||
| ) | ||
| continue | ||
|
|
||
| milvus_repo = get_bean_by_type(repo_class) | ||
|
|
||
| # Handle foresight-specific time range | ||
| type_start_time = start_time_dt | ||
| type_end_time = end_time_dt | ||
| current_time_dt = None | ||
| if mem_type == MemoryType.FORESIGHT: | ||
| if retrieve_mem_request.start_time: | ||
| type_start_time = from_iso_format( | ||
| retrieve_mem_request.start_time | ||
| ) | ||
| if retrieve_mem_request.end_time: | ||
| type_end_time = from_iso_format( | ||
| retrieve_mem_request.end_time | ||
| ) | ||
| if retrieve_mem_request.current_time: | ||
| current_time_dt = from_iso_format( | ||
| retrieve_mem_request.current_time | ||
| ) | ||
|
Comment on lines +570 to +582
|
||
|
|
||
| # Call Milvus vector search | ||
| milvus_start = time.perf_counter() | ||
| if mem_type == MemoryType.FORESIGHT: | ||
| search_results = await milvus_repo.vector_search( | ||
| query_vector=query_vector_list, | ||
| user_id=user_id, | ||
| group_id=group_id, | ||
| start_time=type_start_time, | ||
| end_time=type_end_time, | ||
| current_time=current_time_dt, | ||
| limit=top_k, | ||
| score_threshold=0.0, | ||
| radius=retrieve_mem_request.radius, | ||
| ) | ||
| else: | ||
| search_results = await milvus_repo.vector_search( | ||
| query_vector=query_vector_list, | ||
| user_id=user_id, | ||
| group_id=group_id, | ||
| start_time=start_time_dt, | ||
| end_time=end_time_dt, | ||
| limit=top_k, | ||
| score_threshold=0.0, | ||
| radius=retrieve_mem_request.radius, | ||
| ) | ||
|
Comment on lines +598 to +608
|
||
| record_retrieve_stage( | ||
| retrieve_method=retrieve_method, | ||
| stage='milvus_search', | ||
| memory_type=mem_type.value, | ||
| duration_seconds=time.perf_counter() - milvus_start, | ||
| ) | ||
| record_retrieve_stage( | ||
| retrieve_method=retrieve_method, | ||
| stage='milvus_search', | ||
| memory_type=memory_type, | ||
| duration_seconds=time.perf_counter() - milvus_start, | ||
| ) | ||
|
|
||
| for r in search_results: | ||
| r['memory_type'] = mem_type.value | ||
| r['_search_source'] = RetrieveMethod.VECTOR.value | ||
| # Milvus already uses 'score', no need to rename | ||
| for r in search_results: | ||
| r['memory_type'] = mem_type.value | ||
| r['_search_source'] = RetrieveMethod.VECTOR.value | ||
| all_results.extend(search_results) | ||
|
|
||
| return search_results | ||
| return all_results | ||
| except Exception as e: | ||
| record_retrieve_stage( | ||
| retrieve_method=retrieve_method, | ||
| stage=RetrieveMethod.VECTOR.value, | ||
| memory_type=memory_type, | ||
| duration_seconds=time.perf_counter() - milvus_start, | ||
| ) | ||
| record_retrieve_error( | ||
| retrieve_method=retrieve_method, | ||
| stage=RetrieveMethod.VECTOR.value, | ||
|
|
@@ -625,7 +636,7 @@ async def retrieve_mem_hybrid( | |
| """Hybrid memory retrieval: keyword + vector + rerank""" | ||
| start_time = time.perf_counter() | ||
| memory_type = ( | ||
| retrieve_mem_request.memory_types[0].value | ||
| ','.join(mt.value for mt in retrieve_mem_request.memory_types) | ||
| if retrieve_mem_request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
@@ -700,7 +711,9 @@ async def _search_hybrid( | |
| ) -> List[Dict]: | ||
| """Core hybrid search: keyword + vector + rerank, returns flat list""" | ||
| memory_type = ( | ||
| request.memory_types[0].value if request.memory_types else 'unknown' | ||
| ','.join(mt.value for mt in request.memory_types) | ||
| if request.memory_types | ||
| else 'unknown' | ||
| ) | ||
| # Run keyword and vector search concurrently | ||
| kw_results, vec_results = await asyncio.gather( | ||
|
|
@@ -723,7 +736,9 @@ async def _search_rrf( | |
| ) -> List[Dict]: | ||
| """Core RRF search: keyword + vector + RRF fusion, returns flat list""" | ||
| memory_type = ( | ||
| request.memory_types[0].value if request.memory_types else 'unknown' | ||
| ','.join(mt.value for mt in request.memory_types) | ||
| if request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
||
| # Run keyword and vector search concurrently | ||
|
|
@@ -766,7 +781,11 @@ async def _to_response( | |
| """Convert flat hits list to grouped RetrieveMemResponse""" | ||
| user_id = req.user_id if req else "" | ||
| source_type = req.retrieve_method.value | ||
| memory_type = req.memory_types[0].value | ||
| memory_type = ( | ||
| ','.join(mt.value for mt in req.memory_types) | ||
| if req.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
||
| if not hits: | ||
| return RetrieveMemResponse( | ||
|
|
@@ -809,7 +828,7 @@ async def retrieve_mem_rrf( | |
| """RRF-based memory retrieval: keyword + vector + RRF fusion""" | ||
| start_time = time.perf_counter() | ||
| memory_type = ( | ||
| retrieve_mem_request.memory_types[0].value | ||
| ','.join(mt.value for mt in retrieve_mem_request.memory_types) | ||
| if retrieve_mem_request.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
@@ -855,7 +874,11 @@ async def retrieve_mem_agentic( | |
| req = retrieve_mem_request # alias | ||
| top_k = req.top_k | ||
| config = AgenticConfig() | ||
| memory_type = req.memory_types[0].value if req.memory_types else 'unknown' | ||
| memory_type = ( | ||
| ','.join(mt.value for mt in req.memory_types) | ||
| if req.memory_types | ||
| else 'unknown' | ||
| ) | ||
|
|
||
| try: | ||
| llm_provider = LLMProvider( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `top_k` limit is applied per memory type instead of across all types. In the keyword search loop (lines 386-413), each memory type query uses `size=top_k` (line 401), which means if you request `top_k=5` and search across 3 memory types, you could get up to 15 results before reranking/RRF.

Similarly, in the vector search loop (lines 556-619), each memory type search uses `limit=top_k` (lines 594, 605), which could also return up to `N * top_k` results, where N is the number of memory types.

This behavior may be intentional to ensure diverse results across all types before the final rerank/RRF step reduces them to the requested `top_k`. However, this should be clarified:

- If the intent is `top_k` results per type for better diversity, this is working as expected.
- If the intent is a total of `top_k` results before rerank/RRF, the limit should be adjusted (e.g., `top_k // len(memory_types)` or `top_k * some_multiplier / len(memory_types)`).

The current behavior means that with 3 memory types and `top_k=5`, you get up to 15 intermediate results, which are then reranked/merged down to 5 final results. This seems reasonable for result diversity, but should be documented.