Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 123 additions & 100 deletions src/agentic_layer/memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@
MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository,
}

# MemoryType -> Milvus Repository mapping.
# Mirrors ES_REPO_MAP above so vector search can resolve the per-type
# repository via a dict lookup instead of a match/case chain; memory types
# absent from this map are skipped by the vector-search loop.
MILVUS_REPO_MAP = {
    MemoryType.FORESIGHT: ForesightMilvusRepository,
    MemoryType.EVENT_LOG: EventLogMilvusRepository,
    MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository,
}


@dataclass
class EventLogCandidate:
Expand Down Expand Up @@ -299,7 +306,7 @@ async def retrieve_mem_keyword(
"""Keyword-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -340,7 +347,7 @@ async def get_keyword_search_results(
"""Keyword search with stage-level metrics"""
stage_start = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -375,32 +382,35 @@ async def get_keyword_search_results(
if end_time is not None:
date_range["lte"] = end_time

mem_type = memory_types[0]

repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.warning(f"Unsupported memory_type: {mem_type}")
return []
all_results = []
for mem_type in memory_types:
repo_class = ES_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(
f"Skipping unsupported memory_type for keyword search: {mem_type}"
)
continue

es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")
es_repo = get_bean_by_type(repo_class)
logger.debug(f"Using {repo_class.__name__} for {mem_type}")

results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)
results = await es_repo.multi_search(
query=query_words,
user_id=user_id,
group_id=group_id,
size=top_k,
from_=0,
date_range=date_range,
)

# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
# Mark memory_type, search_source, and unified score
if results:
for r in results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.KEYWORD.value
r['id'] = r.get('_id', '') # Unify ES '_id' to 'id'
r['score'] = r.get('_score', 0.0) # Unified score field
all_results.extend(results)
Comment on lines +386 to +413
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The top_k limit is applied per memory type instead of across all types. In the keyword search loop (lines 386-413), each memory type query uses size=top_k (line 401), which means if you request top_k=5 and search across 3 memory types, you could get up to 15 results before reranking/RRF.

Similarly, in the vector search loop (lines 556-619), each memory type search uses limit=top_k (lines 594, 605), which could also return up to N * top_k results where N is the number of memory types.

This behavior may be intentional to ensure diverse results across all types before the final rerank/RRF step reduces them to the requested top_k. However, this should be clarified:

  1. If the intention is to get top_k results per type for better diversity, this is working as expected
  2. If the intention is to limit total results to top_k before rerank/RRF, the limit should be adjusted (e.g., top_k // len(memory_types) or top_k * some_multiplier / len(memory_types))

The current behavior means with 3 memory types and top_k=5, you get up to 15 intermediate results, which are then reranked/merged down to 5 final results. This seems reasonable for result diversity, but should be documented.

Copilot uses AI. Check for mistakes.

# Record stage metrics
record_retrieve_stage(
Expand All @@ -410,7 +420,7 @@ async def get_keyword_search_results(
duration_seconds=time.perf_counter() - stage_start,
)

return results or []
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
Expand All @@ -434,7 +444,7 @@ async def retrieve_mem_vector(
"""Vector-based memory retrieval"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -474,7 +484,7 @@ async def get_vector_search_results(
) -> List[Dict[str, Any]]:
"""Vector search with stage-level metrics (embedding + milvus_search)"""
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand All @@ -497,7 +507,7 @@ async def get_vector_search_results(
top_k = retrieve_mem_request.top_k
start_time = retrieve_mem_request.start_time
end_time = retrieve_mem_request.end_time
mem_type = retrieve_mem_request.memory_types[0]
memory_types = retrieve_mem_request.memory_types

logger.debug(
f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
Expand All @@ -506,7 +516,7 @@ async def get_vector_search_results(
# Get vectorization service
vectorize_service = get_vectorize_service()

# Convert query text to vector (embedding stage)
# Convert query text to vector (embedding stage) - only once for all types
logger.debug(f"Starting to vectorize query text: {query}")
embedding_start = time.perf_counter()
query_vector = await vectorize_service.get_embedding(query)
Expand All @@ -521,21 +531,9 @@ async def get_vector_search_results(
f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
)

# Select Milvus repository based on memory type
match mem_type:
case MemoryType.FORESIGHT:
milvus_repo = get_bean_by_type(ForesightMilvusRepository)
case MemoryType.EVENT_LOG:
milvus_repo = get_bean_by_type(EventLogMilvusRepository)
case MemoryType.EPISODIC_MEMORY:
milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
case _:
raise ValueError(f"Unsupported memory type: {mem_type}")

# Handle time range filter conditions
start_time_dt = None
end_time_dt = None
current_time_dt = None

if start_time is not None:
start_time_dt = (
Expand All @@ -553,62 +551,75 @@ async def get_vector_search_results(
else:
end_time_dt = end_time

# Handle foresight time range (only valid for foresight)
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
start_time_dt = from_iso_format(retrieve_mem_request.start_time)
if retrieve_mem_request.end_time:
end_time_dt = from_iso_format(retrieve_mem_request.end_time)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(retrieve_mem_request.current_time)

# Call Milvus vector search (pass different parameters based on memory type)
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
# Foresight: supports time range and validity filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
# Episodic memory and event log: use timestamp filtering, supports radius parameter
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
# Search all supported memory types
all_results = []
for mem_type in memory_types:
repo_class = MILVUS_REPO_MAP.get(mem_type)
if not repo_class:
logger.info(
f"Skipping unsupported memory_type for vector search: {mem_type}"
)
continue

milvus_repo = get_bean_by_type(repo_class)

# Handle foresight-specific time range
type_start_time = start_time_dt
type_end_time = end_time_dt
current_time_dt = None
if mem_type == MemoryType.FORESIGHT:
if retrieve_mem_request.start_time:
type_start_time = from_iso_format(
retrieve_mem_request.start_time
)
if retrieve_mem_request.end_time:
type_end_time = from_iso_format(
retrieve_mem_request.end_time
)
if retrieve_mem_request.current_time:
current_time_dt = from_iso_format(
retrieve_mem_request.current_time
)
Comment on lines +570 to +582
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent time parsing for FORESIGHT memory type. When the memory type is FORESIGHT, the time fields are re-parsed from retrieve_mem_request.start_time and retrieve_mem_request.end_time (lines 572-578), but this re-parsing doesn't include the special handling for date-only end_time format that was applied in the initial parsing (lines 546-552).

In the initial parsing at lines 546-552, if end_time is a date-only string (length 10), it's adjusted to the end of day (23:59:59). However, the FORESIGHT-specific re-parsing at lines 575-578 doesn't include this logic, which could lead to inconsistent behavior.

Additionally, this results in redundant parsing - these time values are parsed twice for FORESIGHT types (once at lines 538-552, then again at lines 571-578).

Recommendation: For FORESIGHT, the code should either reuse the already-parsed start_time_dt and end_time_dt values, or ensure the re-parsing includes all the same logic as the initial parsing.

Copilot uses AI. Check for mistakes.

# Call Milvus vector search
milvus_start = time.perf_counter()
if mem_type == MemoryType.FORESIGHT:
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=type_start_time,
end_time=type_end_time,
current_time=current_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
else:
search_results = await milvus_repo.vector_search(
query_vector=query_vector_list,
user_id=user_id,
group_id=group_id,
start_time=start_time_dt,
end_time=end_time_dt,
limit=top_k,
score_threshold=0.0,
radius=retrieve_mem_request.radius,
)
Comment on lines +598 to +608
Copy link

Copilot AI Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-FORESIGHT memory types are using the wrong time variables. At lines 603-604, non-FORESIGHT memory types are passing start_time_dt and end_time_dt to the vector search, but these variables are defined at the function scope level, outside the loop.

However, the loop variable type_start_time and type_end_time were initialized at lines 567-568 specifically to hold the correct time values for each iteration. For non-FORESIGHT types, type_start_time and type_end_time are set to start_time_dt and end_time_dt, but then the code ignores these loop variables and directly uses start_time_dt and end_time_dt anyway.

While this works correctly in the current implementation, it's inconsistent with the FORESIGHT branch (which uses type_start_time and type_end_time at lines 591-592), and the loop variables serve no purpose for non-FORESIGHT types.

Recommendation: For consistency and maintainability, non-FORESIGHT memory types should also use type_start_time and type_end_time instead of start_time_dt and end_time_dt.

Copilot uses AI. Check for mistakes.
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=mem_type.value,
duration_seconds=time.perf_counter() - milvus_start,
)
record_retrieve_stage(
retrieve_method=retrieve_method,
stage='milvus_search',
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)

for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
# Milvus already uses 'score', no need to rename
for r in search_results:
r['memory_type'] = mem_type.value
r['_search_source'] = RetrieveMethod.VECTOR.value
all_results.extend(search_results)

return search_results
return all_results
except Exception as e:
record_retrieve_stage(
retrieve_method=retrieve_method,
stage=RetrieveMethod.VECTOR.value,
memory_type=memory_type,
duration_seconds=time.perf_counter() - milvus_start,
)
record_retrieve_error(
retrieve_method=retrieve_method,
stage=RetrieveMethod.VECTOR.value,
Expand All @@ -625,7 +636,7 @@ async def retrieve_mem_hybrid(
"""Hybrid memory retrieval: keyword + vector + rerank"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -700,7 +711,9 @@ async def _search_hybrid(
) -> List[Dict]:
"""Core hybrid search: keyword + vector + rerank, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
','.join(mt.value for mt in request.memory_types)
if request.memory_types
else 'unknown'
)
# Run keyword and vector search concurrently
kw_results, vec_results = await asyncio.gather(
Expand All @@ -723,7 +736,9 @@ async def _search_rrf(
) -> List[Dict]:
"""Core RRF search: keyword + vector + RRF fusion, returns flat list"""
memory_type = (
request.memory_types[0].value if request.memory_types else 'unknown'
','.join(mt.value for mt in request.memory_types)
if request.memory_types
else 'unknown'
)

# Run keyword and vector search concurrently
Expand Down Expand Up @@ -766,7 +781,11 @@ async def _to_response(
"""Convert flat hits list to grouped RetrieveMemResponse"""
user_id = req.user_id if req else ""
source_type = req.retrieve_method.value
memory_type = req.memory_types[0].value
memory_type = (
','.join(mt.value for mt in req.memory_types)
if req.memory_types
else 'unknown'
)

if not hits:
return RetrieveMemResponse(
Expand Down Expand Up @@ -809,7 +828,7 @@ async def retrieve_mem_rrf(
"""RRF-based memory retrieval: keyword + vector + RRF fusion"""
start_time = time.perf_counter()
memory_type = (
retrieve_mem_request.memory_types[0].value
','.join(mt.value for mt in retrieve_mem_request.memory_types)
if retrieve_mem_request.memory_types
else 'unknown'
)
Expand Down Expand Up @@ -855,7 +874,11 @@ async def retrieve_mem_agentic(
req = retrieve_mem_request # alias
top_k = req.top_k
config = AgenticConfig()
memory_type = req.memory_types[0].value if req.memory_types else 'unknown'
memory_type = (
','.join(mt.value for mt in req.memory_types)
if req.memory_types
else 'unknown'
)

try:
llm_provider = LLMProvider(
Expand Down
Loading