Patch: make the result-fusion weights configurable via environment variables
(FUSION_VECTOR_WEIGHT / FUSION_KS_WEIGHT) instead of the literals 0.6 / 0.4.

NOTE(review): this patch was restored from a copy whose line breaks had been
collapsed. Python indentation and the blank context lines in the
backend/agents.py import hunk are inferred from the hunk line counts - confirm
against the target file before applying. The `index` hashes are carried over
from the original patch.

diff --git a/.env.template b/.env.template
index 48fe88a..4e8d34a 100644
--- a/.env.template
+++ b/.env.template
@@ -26,5 +26,10 @@ ELASTIC_BASE_URL=
 ELASTIC_USERNAME=
 ELASTIC_PASSWORD=
 PAGE_SIZE=1000
+
+# Fusion weights (result ranking)
+FUSION_VECTOR_WEIGHT=0.6
+FUSION_KS_WEIGHT=0.4
+
 GCS_BUCKET=
 GCS_PREFIX=
diff --git a/backend/agents.py b/backend/agents.py
index 155ce0a..aa5ddab 100644
--- a/backend/agents.py
+++ b/backend/agents.py
@@ -14,6 +14,14 @@ from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
 from retrieval import get_retriever
 
 
+# Fusion weights, configurable via environment variables.
+# Defaults preserve existing behavior (vector=0.6, keyword search=0.4).
+VECTOR_WEIGHT = float(os.getenv("FUSION_VECTOR_WEIGHT", "0.6"))
+KS_WEIGHT = float(os.getenv("FUSION_KS_WEIGHT", "0.4"))
+
+# float() also parses "nan" and "inf"; the range test rejects those as well as negatives.
+if not all(0 <= w < float("inf") for w in (VECTOR_WEIGHT, KS_WEIGHT)):
+    raise ValueError("Fusion weights must be finite and non-negative")
 
 # LLM (Gemini) client setup
 try:
@@ -443,14 +451,14 @@ def fuse_results(state: AgentState) -> AgentState:
     for res in vector_results:
         if isinstance(res, dict):
             doc_id = res.get("id") or res.get("_id") or f"vec_{len(combined)}"
-            combined[doc_id] = {**res, "final_score": res.get("similarity", 0) * 0.6}
+            combined[doc_id] = {**res, "final_score": res.get("similarity", 0) * VECTOR_WEIGHT}
     for res in ks_results:
         if isinstance(res, dict):
             doc_id = res.get("_id") or res.get("id") or f"ks_{len(combined)}"
             if doc_id in combined:
-                combined[doc_id]["final_score"] += res.get("_score", 0) * 0.4
+                combined[doc_id]["final_score"] += res.get("_score", 0) * KS_WEIGHT
             else:
-                combined[doc_id] = {**res, "final_score": res.get("_score", 0) * 0.4}
+                combined[doc_id] = {**res, "final_score": res.get("_score", 0) * KS_WEIGHT}
     all_sorted = sorted(combined.values(), key=lambda x: x.get("final_score", 0), reverse=True)
     logger.info(
         "Results summary: KS=%d, Vector=%d, Combined=%d",