From 36ab10601dc17a46c84f7ad7b662581c59c93dc9 Mon Sep 17 00:00:00 2001 From: Zohaib Date: Tue, 17 Mar 2026 01:53:21 +0500 Subject: [PATCH 1/4] feat: re-implement metadata reranking with strict log and min-max bounds --- backend/ks_search_tool.py | 322 ++++++++++++++++++++++---- backend/tests/test_metadata_rerank.py | 94 ++++++++ 2 files changed, 369 insertions(+), 47 deletions(-) create mode 100644 backend/tests/test_metadata_rerank.py diff --git a/backend/ks_search_tool.py b/backend/ks_search_tool.py index 3004a02..95315d3 100644 --- a/backend/ks_search_tool.py +++ b/backend/ks_search_tool.py @@ -8,12 +8,157 @@ import re from urllib.parse import urlparse from difflib import SequenceMatcher +import math + +# --- Query Expansion for Neuroscience Terms --- +QUERY_SYNONYMS = { + "mouse brain": [ + "Rattus norvegicus", + "somatosensory cortex", + "cortex", + "hippocampus", + ], + "memory": ["hippocampus", "synaptic plasticity"], + "hippocampus": ["CA1", "CA3", "dentate gyrus"], + "eeg": ["electroencephalography"], + "fmri": ["functional magnetic resonance imaging", "BOLD"], +} + + +def expand_query(query: str) -> str: + if not query: + return "" + query_lower = query.lower() + expanded = [query_lower] + added_terms = set(expanded) + + for phrase, synonyms in QUERY_SYNONYMS.items(): + if phrase in query_lower: + for syn in synonyms: + if syn not in added_terms: + expanded.append(syn) + added_terms.add(syn) + + for word in query_lower.split(): + if word in QUERY_SYNONYMS: + for syn in QUERY_SYNONYMS[word]: + if syn not in added_terms: + expanded.append(syn) + added_terms.add(syn) + + return " OR ".join([f'"{t}"' if " " in t else t for t in expanded]) + + +# --- Metadata Normalized Reranking --- +def rerank_results_using_metadata(results: List[dict]) -> List[dict]: + """ + Safely boosts semantic/keyword search scores using bounded metadata signals. + Employs min-max scaling and log-normalization so high citations/years + do not overpower the core relevance score. + """ + if not results: + return [] + + # 1. Gather all raw metadata values across the batch to find min/max + years = [] + citations_list = [] + + for r in results: + meta = r.get("metadata", {}) or {} + + # Extract Year safely handling falsy 0s + y = meta.get("publication_year") + if y is None: + y = meta.get("year") + + try: + if y is not None: + years.append(float(y)) + except (ValueError, TypeError): + pass + + # Extract Citations safely handling falsy 0s + c = meta.get("citations") + if c is None: + c = meta.get("citation_count") + + try: + if c is not None: + citations_list.append(float(c)) + except (ValueError, TypeError): + pass + + # Find the bounds for scaling + min_year = min(years) if years else 0 + max_year = max(years) if years else 0 + + min_cits_log = math.log10(min(citations_list) + 1) if citations_list else 0 + max_cits_log = math.log10(max(citations_list) + 1) if citations_list else 0 + + # 2. Score and Boost each item + trusted_sources = {"allen brain atlas", "gensat", "ebrains"} + + def calculate_boost(r: dict) -> float: + meta = r.get("metadata", {}) or {} + + # -- Year Boost (Max +10%) + year_boost = 0.0 + y = meta.get("publication_year") + if y is None: + y = meta.get("year") + + if y is not None and max_year > min_year: + try: + # 0.0 to 1.0 based on how new it is in this batch + y_norm = (float(y) - min_year) / (max_year - min_year) + year_boost = y_norm * 0.10 + except (ValueError, TypeError): + pass + + # -- Citations Boost (Max +15%) + cit_boost = 0.0 + c = meta.get("citations") + if c is None: + c = meta.get("citation_count") + + if c is not None and max_cits_log > min_cits_log: + try: + # Log normalize out extreme outliers (e.g. 10k citations) + c_log = math.log10(float(c) + 1) + c_norm = (c_log - min_cits_log) / (max_cits_log - min_cits_log) + cit_boost = c_norm * 0.15 + except (ValueError, TypeError): + pass + + # -- Trusted Source Boost (Max +5%) + source_boost = 0.0 + s1 = r.get("datasource_name") + s2 = meta.get("source") + # Ensure we fall back if s1 is None + source_name = str( + s1 if s1 is not None else (s2 if s2 is not None else "") + ).lower() + if any(ts in source_name for ts in trusted_sources): + source_boost = 0.05 + + # Cumulative multiplier (e.g. 1.0 to 1.30) + return 1.0 + year_boost + cit_boost + source_boost + + for item in results: + base_score = float(item.get("_score", 1.0)) + multiplier = calculate_boost(item) + item["_score"] = base_score * multiplier + item["_rerank_multiplier"] = multiplier + + # Re-sort descending based on the new multiplied scores + return sorted(results, key=lambda x: float(x.get("_score", 0)), reverse=True) def tool(args_schema): def decorator(func): func.args_schema = args_schema return func + return decorator @@ -56,7 +201,9 @@ def fuzzy_match(query: str, target: str, threshold: float = 0.8) -> bool: return similarity >= threshold -def find_best_matches(query: str, candidates: List[str], threshold: float = 0.8, max_matches: int = 5) -> List[str]: +def find_best_matches( + query: str, candidates: List[str], threshold: float = 0.8, max_matches: int = 5 +) -> List[str]: matches = [] for candidate in candidates: if fuzzy_match(query, candidate, threshold): @@ -66,7 +213,9 @@ def find_best_matches(query: str, candidates: List[str], threshold: float = 0.8, return [match[0] for match in matches[:max_matches]] -def search_across_all_fields(query: str, all_configs: dict, threshold: float = 0.8) -> List[dict]: +def search_across_all_fields( + query: str, all_configs: dict, threshold: float = 0.8 +) -> List[dict]: """ Keyword search across all available field value lists in all datasources (fuzzy). """ @@ -86,7 +235,9 @@ def search_across_all_fields(query: str, all_configs: dict, threshold: float = 0 ) results.extend(search_results) except Exception as e: - print(f"Error searching {datasource_id} with field {field_name}: {e}") + print( + f"Error searching {datasource_id} with field {field_name}: {e}" + ) continue return results @@ -102,7 +253,7 @@ def global_fuzzy_keyword_search(keywords: Iterable[str], top_k: int = 20) -> Lis all_configs = json.load(fh) out: List[dict] = [] seen = set() - for kw in (keywords or []): + for kw in keywords or []: if not kw: continue results = search_across_all_fields(kw, all_configs, threshold=0.8) @@ -162,7 +313,9 @@ def extract_datasource_info_from_link(link: str) -> tuple: return None, None -async def fetch_dataset_details_async(session, datasource_id: str, dataset_id: str) -> dict: +async def fetch_dataset_details_async( + session, datasource_id: str, dataset_id: str +) -> dict: if not datasource_id or not dataset_id: return {} try: @@ -174,6 +327,7 @@ async def fetch_dataset_details_async(session, datasource_id: str, dataset_id: s print(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") return {} + def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: if not datasource_id or not dataset_id: return {} @@ -187,74 +341,92 @@ def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: return {} -async def enrich_with_dataset_details_async(results: List[dict], top_k: int = 10) -> List[dict]: +async def enrich_with_dataset_details_async( + results: List[dict], top_k: int = 10 +) -> List[dict]: """ Parallel enrichment - fetches dataset details for multiple datasets simultaneously. Instead of: fetch dataset1 -> wait -> fetch dataset2 -> wait -> fetch dataset3 We do: fetch dataset1, dataset2, dataset3 ALL AT ONCE -> wait for all to complete This can reduce enrichment time from 3+ seconds to <1 second for 10 datasets. """ - + async def enrich_single_result(session, result, index): try: # Extract datasource info - link = result.get("primary_link", "") or result.get("metadata", {}).get("url", "") + link = result.get("primary_link", "") or result.get("metadata", {}).get( + "url", "" + ) datasource_id, dataset_id = extract_datasource_info_from_link(link) - + if not datasource_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - source_info = metadata.get("source", "") or metadata.get("datasource", "") + source_info = metadata.get("source", "") or metadata.get( + "datasource", "" + ) if source_info: for name, ds_id in DATASOURCE_NAME_TO_ID.items(): if name.lower() in str(source_info).lower(): datasource_id = ds_id break - + if datasource_id and not dataset_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - dataset_id = metadata.get("id", "") or metadata.get("dataset_id", "") or result.get("_id", "") - + dataset_id = ( + metadata.get("id", "") + or metadata.get("dataset_id", "") + or result.get("_id", "") + ) + # Fetch details if we have both IDs if datasource_id and dataset_id: - print(f" -> Parallel fetching details for {datasource_id}/{dataset_id}") - details = await fetch_dataset_details_async(session, datasource_id, dataset_id) + print( + f" -> Parallel fetching details for {datasource_id}/{dataset_id}" + ) + details = await fetch_dataset_details_async( + session, datasource_id, dataset_id + ) if details: result["detailed_info"] = details result["datasource_id"] = datasource_id - result["datasource_name"] = DATASOURCE_ID_TO_NAME.get(datasource_id, datasource_id) + result["datasource_name"] = DATASOURCE_ID_TO_NAME.get( + datasource_id, datasource_id + ) if "metadata" not in result: result["metadata"] = {} result["metadata"].update(details) - + return result, index - + except Exception as e: print(f" -> Error enriching result {index}: {e}") return result, index # Return original result if enrichment fails - + # Create HTTP session with connection pooling connector = aiohttp.TCPConnector( limit=20, # Total connection pool limit_per_host=10, # Max 10 connections per host - keepalive_timeout=30 + keepalive_timeout=30, ) - + async with aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=8, connect=2) + connector=connector, timeout=aiohttp.ClientTimeout(total=8, connect=2) ) as session: # Create tasks for ALL results at once - this is the "parallel" part - tasks = [enrich_single_result(session, result, i) for i, result in enumerate(results[:top_k])] - + tasks = [ + enrich_single_result(session, result, i) + for i, result in enumerate(results[:top_k]) + ] + print(f" -> Starting {len(tasks)} parallel enrichment tasks") start_time = asyncio.get_event_loop().time() - + # Execute ALL tasks simultaneously completed_results = await asyncio.gather(*tasks, return_exceptions=True) - + end_time = asyncio.get_event_loop().time() print(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s") - + # Reconstruct results in original order enriched_results = [None] * len(results[:top_k]) for item in completed_results: @@ -263,14 +435,17 @@ async def enrich_single_result(session, result, index): continue result, index = item enriched_results[index] = result - + # Filter out None values and return return [r for r in enriched_results if r is not None] + def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[dict]: enriched_results = [] for i, result in enumerate(results[:top_k]): - link = result.get("primary_link", "") or result.get("metadata", {}).get("url", "") + link = result.get("primary_link", "") or result.get("metadata", {}).get( + "url", "" + ) datasource_id, dataset_id = extract_datasource_info_from_link(link) if not datasource_id: metadata = result.get("metadata", {}) or result.get("_source", {}) @@ -282,13 +457,19 @@ def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[di break if datasource_id and not dataset_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - dataset_id = metadata.get("id", "") or metadata.get("dataset_id", "") or result.get("_id", "") + dataset_id = ( + metadata.get("id", "") + or metadata.get("dataset_id", "") + or result.get("_id", "") + ) if datasource_id and dataset_id: details = fetch_dataset_details(datasource_id, dataset_id) if details: result["detailed_info"] = details result["datasource_id"] = datasource_id - result["datasource_name"] = DATASOURCE_ID_TO_NAME.get(datasource_id, datasource_id) + result["datasource_name"] = DATASOURCE_ID_TO_NAME.get( + datasource_id, datasource_id + ) if "metadata" not in result: result["metadata"] = {} result["metadata"].update(details) @@ -296,7 +477,9 @@ def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[di return enriched_results -async def general_search_async(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: +async def general_search_async( + query: str, top_k: int = 10, enrich_details: bool = True +) -> dict: """Async version of general search with parallel enrichment""" print("--> Executing async general search...") base_url = "https://api.knowledge-space.org/datasets/search" @@ -306,12 +489,22 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool async with session.get(base_url, params=params, timeout=15) as resp: resp.raise_for_status() data = await resp.json() - + results_list = data.get("results", []) normalized_results = [] for i, item in enumerate(results_list): - title = item.get("title") or item.get("name") or item.get("dc.title") or "Dataset" - description = item.get("description") or item.get("abstract") or item.get("summary") or "" + title = ( + item.get("title") + or item.get("name") + or item.get("dc.title") + or "Dataset" + ) + description = ( + item.get("description") + or item.get("abstract") + or item.get("summary") + or "" + ) url = ( item.get("url") or item.get("link") @@ -333,12 +526,16 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool print(f" -> Async general search returned {len(normalized_results)} results") if enrich_details and normalized_results: print(" -> Using parallel async enrichment...") - normalized_results = await enrich_with_dataset_details_async(normalized_results, top_k) + normalized_results = await enrich_with_dataset_details_async( + normalized_results, top_k + ) + normalized_results = rerank_results_using_metadata(normalized_results) return {"combined_results": normalized_results[:top_k]} except Exception as e: print(f" -> Error during async general search: {e}") return {"combined_results": []} + def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: print("--> Executing general search...") base_url = "https://api.knowledge-space.org/datasets/search" @@ -350,8 +547,18 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> results_list = data.get("results", []) normalized_results = [] for i, item in enumerate(results_list): - title = item.get("title") or item.get("name") or item.get("dc.title") or "Dataset" - description = item.get("description") or item.get("abstract") or item.get("summary") or "" + title = ( + item.get("title") + or item.get("name") + or item.get("dc.title") + or "Dataset" + ) + description = ( + item.get("description") + or item.get("abstract") + or item.get("summary") + or "" + ) url = ( item.get("url") or item.get("link") @@ -373,17 +580,24 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> ) print(f" -> General search returned {len(normalized_results)} results") if enrich_details and normalized_results: - print(" -> Enriching results with detailed dataset information (parallel)...") + print( + " -> Enriching results with detailed dataset information (parallel)..." + ) # Use sync enrichment for now - we'll make the whole function async later normalized_results = enrich_with_dataset_details(normalized_results, top_k) + normalized_results = rerank_results_using_metadata(normalized_results) return {"combined_results": normalized_results[:top_k]} except requests.RequestException as e: print(f" -> Error during general search: {e}") return {"combined_results": []} -def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10) -> List[dict]: - print(f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'") +def _perform_search( + data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10 +) -> List[dict]: + print( + f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'" + ) base_url = "https://knowledge-space.org/entity/source-data-by-entity" valid_filter_map = all_configs.get(data_source_id, {}).get("available_filters", {}) exact_match_filters = [] @@ -413,13 +627,24 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: resp = requests.get(base_url, params=params, timeout=timeout) resp.raise_for_status() data = resp.json() - hits = (data[0] if isinstance(data, list) and data else data).get("hits", {}).get("hits", []) + hits = ( + (data[0] if isinstance(data, list) and data else data) + .get("hits", {}) + .get("hits", []) + ) print(f" -> Retrieved {len(hits)} raw results") out = [] for hit in hits: src = hit.get("_source", {}) or {} - title = src.get("title") or src.get("name") or src.get("dc.title") or "Dataset" - desc = src.get("description") or src.get("abstract") or src.get("summary") or "" + title = ( + src.get("title") or src.get("name") or src.get("dc.title") or "Dataset" + ) + desc = ( + src.get("description") + or src.get("abstract") + or src.get("summary") + or "" + ) link = ( src.get("url") or src.get("link") @@ -439,6 +664,7 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: "metadata": src, } ) + out = rerank_results_using_metadata(out) return out except requests.RequestException as e: print(f" -> Error searching {data_source_id}: {e}") @@ -452,13 +678,15 @@ def smart_knowledge_search( data_source: Optional[str] = None, top_k: int = 10, ) -> dict: - q = query or "*" + q = expand_query(query) if query else "*" if filters: config_path = "datasources_config.json" if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as fh: all_configs = json.load(fh) - target_id = DATASOURCE_NAME_TO_ID.get(data_source) or (data_source if data_source in all_configs else None) + target_id = DATASOURCE_NAME_TO_ID.get(data_source) or ( + data_source if data_source in all_configs else None + ) if target_id: results = _perform_search(target_id, q, dict(filters), all_configs) return {"combined_results": results[:top_k]} diff --git a/backend/tests/test_metadata_rerank.py b/backend/tests/test_metadata_rerank.py new file mode 100644 index 0000000..ec4fe08 --- /dev/null +++ b/backend/tests/test_metadata_rerank.py @@ -0,0 +1,94 @@ +import pytest +from ks_search_tool import rerank_results_using_metadata + + +def test_rerank_max_bounds(): + """ + Test that the maximum possible boost is exactly +30% + (10% for Year, 15% for Citations, 5% for Trusted Source) + """ + results = [ + # Baseline dataset + { + "_score": 100.0, + "title_guess": "Old Data", + "metadata": {"year": 1990, "citations": 0, "source": "Unknown"}, + }, + # Perfect dataset that should get the max 1.30x multiplier + { + "_score": 100.0, + "title_guess": "Perfect Data", + "metadata": { + "year": 2024, + "citations": 10000, + "source": "Allen Brain Atlas", + }, + }, + ] + + ranked = rerank_results_using_metadata(results) + + # "Perfect Data" should be first due to boost + assert ranked[0]["title_guess"] == "Perfect Data" + + # Baseline should remain exactly 100.0 (no multiplier via min scaling) + assert ranked[1]["_score"] == 100.0 + + # Perfect Data should be exactly 130.0 (1.30x multiplier) + assert ranked[0]["_score"] == 130.0 + assert ranked[0]["_rerank_multiplier"] == 1.30 + + +def test_rerank_log_normalization(): + """ + Test that 10k citations doesn't astronomically outscore 10 citations + thanks to log normalization. + """ + results = [ + {"_score": 100.0, "title_guess": "Zero Cits", "metadata": {"citations": 0}}, + {"_score": 100.0, "title_guess": "Ten Cits", "metadata": {"citations": 10}}, + { + "_score": 100.0, + "title_guess": "Ten Thousand Cits", + "metadata": {"citations": 10000}, + }, + ] + + ranked = rerank_results_using_metadata(results) + + # Highest should still be first + assert ranked[0]["title_guess"] == "Ten Thousand Cits" + + multiplier_high = ranked[0]["_rerank_multiplier"] + multiplier_mid = ranked[1]["_rerank_multiplier"] + multiplier_low = ranked[2]["_rerank_multiplier"] + + # Verify the bounded maximum is respected (max +15% for citations) + assert multiplier_high == 1.15 + assert multiplier_low == 1.00 + + # 10 citations should give a meaningful logarithmic boost (log10(11) / log10(10001)) * 0.15 + # Let's just assert it is meaningfully greater than 1.0 but less than 1.15 + assert 1.0 < multiplier_mid < 1.15 + + +def test_rerank_empty_metadata_handling(): + """ + Test that datasets missing metadata fields do not break the calculation. + """ + results = [ + {"_score": 10.0, "title_guess": "No Meta1"}, + {"_score": 10.0, "title_guess": "No Meta2", "metadata": {}}, + { + "_score": 10.0, + "title_guess": "Garbage Meta", + "metadata": {"year": "unknown", "citations": None}, + }, + ] + + ranked = rerank_results_using_metadata(results) + + # All should retain their base score of 10.0 + for r in ranked: + assert r["_score"] == 10.0 + assert r["_rerank_multiplier"] == 1.0 From 16164bc6fe06a578e205b69dd2e02c323d66c181 Mon Sep 17 00:00:00 2001 From: Zohaib Date: Thu, 19 Mar 2026 02:53:34 +0500 Subject: [PATCH 2/4] fix(retrieval): fix issues from #74 (BQ timeout, inverted sim, etc) --- backend/retrieval.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backend/retrieval.py b/backend/retrieval.py index 31a0452..30bec1f 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -78,7 +78,7 @@ def __init__(self): self.embed_model_name = os.getenv("EMBED_MODEL_NAME", "nomic-ai/nomic-embed-text-v1.5") self.bq_dataset = os.getenv("BQ_DATASET_ID", "ks_metadata") self.bq_table = os.getenv("BQ_TABLE_ID", "docstore") - self.bq_location = os.getenv("BQ_LOCATION","EU") + self.bq_location = os.getenv("BQ_LOCATION","US") try: self.embed_max_tokens = int(os.getenv("EMBED_MAX_TOKENS", "1024")) except Exception: @@ -164,7 +164,7 @@ def _bq_fetch(self, ids: List[str]) -> Dict[str, Dict[str, Any]]: cfg = bigquery.QueryJobConfig( query_parameters=[bigquery.ArrayQueryParameter("ids", "STRING", ids)] ) - rows = self.bq.query(sql, job_config=cfg, location=self.bq_location).result() + rows = self.bq.query(sql, job_config=cfg, location=self.bq_location).result(timeout=10) out: Dict[str, Dict[str, Any]] = {} for r in rows: md = r.metadata_filters @@ -190,7 +190,7 @@ def search( if not self.is_enabled or not query: return [] - qtext = query if (context or {}).get("raw") else query + qtext = query try: vec = self._embed(qtext) @@ -231,10 +231,15 @@ def search( or "" ) try: + # Vertex AI returns L2 distance (lower is better), so we negate it for descending similarity sort similarity = -float(dist) if dist is not None else 0.0 except Exception: similarity = 0.0 + other_links = md.get("other_links", []) + if not isinstance(other_links, list): + other_links = [] + items.append( RetrievedItem( id=dp_id, @@ -242,7 +247,7 @@ def search( content=str(content), metadata=md, primary_link=link, - other_links=[], + other_links=other_links, similarity=similarity, ) ) From 2a12aa5b8e3f95ad1dd3cf51828542860ba73ba0 Mon Sep 17 00:00:00 2001 From: Zohaib Date: Fri, 20 Mar 2026 11:38:36 +0500 Subject: [PATCH 3/4] Revert BQ_LOCATION to EU and remove reranking/query expansion from ks_search_tool.py --- backend/ks_search_tool.py | 151 +------------------------------------- backend/retrieval.py | 4 +- 2 files changed, 6 insertions(+), 149 deletions(-) diff --git a/backend/ks_search_tool.py b/backend/ks_search_tool.py index 95315d3..7996c5e 100644 --- a/backend/ks_search_tool.py +++ b/backend/ks_search_tool.py @@ -8,150 +8,7 @@ import re from urllib.parse import urlparse from difflib import SequenceMatcher -import math - -# --- Query Expansion for Neuroscience Terms --- -QUERY_SYNONYMS = { - "mouse brain": [ - "Rattus norvegicus", - "somatosensory cortex", - "cortex", - "hippocampus", - ], - "memory": ["hippocampus", "synaptic plasticity"], - "hippocampus": ["CA1", "CA3", "dentate gyrus"], - "eeg": ["electroencephalography"], - "fmri": ["functional magnetic resonance imaging", "BOLD"], -} - - -def expand_query(query: str) -> str: - if not query: - return "" - query_lower = query.lower() - expanded = [query_lower] - added_terms = set(expanded) - - for phrase, synonyms in QUERY_SYNONYMS.items(): - if phrase in query_lower: - for syn in synonyms: - if syn not in added_terms: - expanded.append(syn) - added_terms.add(syn) - - for word in query_lower.split(): - if word in QUERY_SYNONYMS: - for syn in QUERY_SYNONYMS[word]: - if syn not in added_terms: - expanded.append(syn) - added_terms.add(syn) - - return " OR ".join([f'"{t}"' if " " in t else t for t in expanded]) - - -# --- Metadata Normalized Reranking --- -def rerank_results_using_metadata(results: List[dict]) -> List[dict]: - """ - Safely boosts semantic/keyword search scores using bounded metadata signals. - Employs min-max scaling and log-normalization so high citations/years - do not overpower the core relevance score. - """ - if not results: - return [] - # 1. Gather all raw metadata values across the batch to find min/max - years = [] - citations_list = [] - - for r in results: - meta = r.get("metadata", {}) or {} - - # Extract Year safely handling falsy 0s - y = meta.get("publication_year") - if y is None: - y = meta.get("year") - - try: - if y is not None: - years.append(float(y)) - except (ValueError, TypeError): - pass - - # Extract Citations safely handling falsy 0s - c = meta.get("citations") - if c is None: - c = meta.get("citation_count") - - try: - if c is not None: - citations_list.append(float(c)) - except (ValueError, TypeError): - pass - - # Find the bounds for scaling - min_year = min(years) if years else 0 - max_year = max(years) if years else 0 - - min_cits_log = math.log10(min(citations_list) + 1) if citations_list else 0 - max_cits_log = math.log10(max(citations_list) + 1) if citations_list else 0 - - # 2. Score and Boost each item - trusted_sources = {"allen brain atlas", "gensat", "ebrains"} - - def calculate_boost(r: dict) -> float: - meta = r.get("metadata", {}) or {} - - # -- Year Boost (Max +10%) - year_boost = 0.0 - y = meta.get("publication_year") - if y is None: - y = meta.get("year") - - if y is not None and max_year > min_year: - try: - # 0.0 to 1.0 based on how new it is in this batch - y_norm = (float(y) - min_year) / (max_year - min_year) - year_boost = y_norm * 0.10 - except (ValueError, TypeError): - pass - - # -- Citations Boost (Max +15%) - cit_boost = 0.0 - c = meta.get("citations") - if c is None: - c = meta.get("citation_count") - - if c is not None and max_cits_log > min_cits_log: - try: - # Log normalize out extreme outliers (e.g. 10k citations) - c_log = math.log10(float(c) + 1) - c_norm = (c_log - min_cits_log) / (max_cits_log - min_cits_log) - cit_boost = c_norm * 0.15 - except (ValueError, TypeError): - pass - - # -- Trusted Source Boost (Max +5%) - source_boost = 0.0 - s1 = r.get("datasource_name") - s2 = meta.get("source") - # Ensure we fall back if s1 is None - source_name = str( - s1 if s1 is not None else (s2 if s2 is not None else "") - ).lower() - if any(ts in source_name for ts in trusted_sources): - source_boost = 0.05 - - # Cumulative multiplier (e.g. 1.0 to 1.30) - return 1.0 + year_boost + cit_boost + source_boost - - for item in results: - base_score = float(item.get("_score", 1.0)) - multiplier = calculate_boost(item) - item["_score"] = base_score * multiplier - item["_rerank_multiplier"] = multiplier - - # Re-sort descending based on the new multiplied scores - return sorted(results, key=lambda x: float(x.get("_score", 0)), reverse=True) def tool(args_schema): @@ -529,7 +386,7 @@ async def general_search_async( normalized_results = await enrich_with_dataset_details_async( normalized_results, top_k ) - normalized_results = rerank_results_using_metadata(normalized_results) + return {"combined_results": normalized_results[:top_k]} except Exception as e: print(f" -> Error during async general search: {e}") @@ -585,7 +442,7 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> ) # Use sync enrichment for now - we'll make the whole function async later normalized_results = enrich_with_dataset_details(normalized_results, top_k) - normalized_results = rerank_results_using_metadata(normalized_results) + return {"combined_results": normalized_results[:top_k]} except requests.RequestException as e: print(f" -> Error during general search: {e}") @@ -664,7 +521,7 @@ def _perform_search( "metadata": src, } ) - out = rerank_results_using_metadata(out) + return out except requests.RequestException as e: print(f" -> Error searching {data_source_id}: {e}") @@ -678,7 +535,7 @@ def smart_knowledge_search( data_source: Optional[str] = None, top_k: int = 10, ) -> dict: - q = expand_query(query) if query else "*" + q = query or "*" if filters: config_path = "datasources_config.json" if os.path.exists(config_path): diff --git a/backend/retrieval.py b/backend/retrieval.py index 30bec1f..a8702ab 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -62,7 +62,7 @@ class VertexRetriever(BaseRetriever): - EMBED_MODEL_NAME default: nomic-ai/nomic-embed-text-v1.5 - BQ_DATASET_ID default: ks_metadata - BQ_TABLE_ID default: docstore - - BQ_LOCATION default: US + - BQ_LOCATION default: EU - EMBED_MAX_TOKENS default: 1024 - QUERY_CHAR_LIMIT default: 8000 """ @@ -78,7 +78,7 @@ def __init__(self): self.embed_model_name = os.getenv("EMBED_MODEL_NAME", "nomic-ai/nomic-embed-text-v1.5") self.bq_dataset = os.getenv("BQ_DATASET_ID", "ks_metadata") self.bq_table = os.getenv("BQ_TABLE_ID", "docstore") - self.bq_location = os.getenv("BQ_LOCATION","US") + self.bq_location = os.getenv("BQ_LOCATION","EU") try: self.embed_max_tokens = int(os.getenv("EMBED_MAX_TOKENS", "1024")) except Exception: From 5498209241978ea2977956d26e089b75cb90c1c9 Mon Sep 17 00:00:00 2001 From: Zohaib Date: Fri, 20 Mar 2026 11:56:24 +0500 Subject: [PATCH 4/4] Clean up: remove test_metadata_rerank.py and replace print() with logging --- backend/ks_search_tool.py | 46 +++++++------ backend/tests/test_metadata_rerank.py | 94 --------------------------- 2 files changed, 27 insertions(+), 113 deletions(-) delete mode 100644 backend/tests/test_metadata_rerank.py diff --git a/backend/ks_search_tool.py b/backend/ks_search_tool.py index 7996c5e..242083e 100644 --- a/backend/ks_search_tool.py +++ b/backend/ks_search_tool.py @@ -1,6 +1,14 @@ # ks_search_tool.py import os import json +import logging + +logger = logging.getLogger("ks_search_tool") +logger.setLevel(logging.INFO) +if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) + logger.addHandler(_h) import requests import asyncio import aiohttp @@ -92,7 +100,7 @@ def search_across_all_fields( ) results.extend(search_results) except Exception as e: - print( + logger.info( f"Error searching {datasource_id} with field {field_name}: {e}" ) continue @@ -181,7 +189,7 @@ async def fetch_dataset_details_async( resp.raise_for_status() return await resp.json() except Exception as e: - print(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") + logger.error(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") return {} @@ -194,7 +202,7 @@ def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: resp.raise_for_status() return resp.json() except Exception as e: - print(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") + logger.error(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") return {} @@ -237,7 +245,7 @@ async def enrich_single_result(session, result, index): # Fetch details if we have both IDs if datasource_id and dataset_id: - print( + logger.info( f" -> Parallel fetching details for {datasource_id}/{dataset_id}" ) details = await fetch_dataset_details_async( @@ -256,7 +264,7 @@ async def enrich_single_result(session, result, index): return result, index except Exception as e: - print(f" -> Error enriching result {index}: {e}") + logger.error(f" -> Error enriching result {index}: {e}") return result, index # Return original result if enrichment fails # Create HTTP session with connection pooling @@ -275,20 +283,20 @@ async def enrich_single_result(session, result, index): for i, result in enumerate(results[:top_k]) ] - print(f" -> Starting {len(tasks)} parallel enrichment tasks") + logger.info(f" -> Starting {len(tasks)} parallel enrichment tasks") start_time = asyncio.get_event_loop().time() # Execute ALL tasks simultaneously completed_results = await asyncio.gather(*tasks, return_exceptions=True) end_time = asyncio.get_event_loop().time() - print(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s") + logger.info(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s") # Reconstruct results in original order enriched_results = [None] * len(results[:top_k]) for item in completed_results: if isinstance(item, Exception): - print(f" -> Task failed: {item}") + logger.info(f" -> Task failed: {item}") continue result, index = item enriched_results[index] = result @@ -338,7 +346,7 @@ async def general_search_async( query: str, top_k: int = 10, enrich_details: bool = True ) -> dict: """Async version of general search with parallel enrichment""" - print("--> Executing async general search...") + logger.info("--> Executing async general search...") base_url = "https://api.knowledge-space.org/datasets/search" params = {"q": query or "*", "per_page": min(top_k * 2, 50)} try: @@ -380,21 +388,21 @@ async def general_search_async( "metadata": item, } ) - print(f" -> Async general search returned {len(normalized_results)} results") + logger.info(f" -> Async general search returned {len(normalized_results)} results") if enrich_details and normalized_results: - print(" -> Using parallel async enrichment...") + logger.info(" -> Using parallel async enrichment...") normalized_results = await enrich_with_dataset_details_async( normalized_results, top_k ) return {"combined_results": normalized_results[:top_k]} except Exception as e: - print(f" -> Error during async general search: {e}") + logger.error(f" -> Error during async general search: {e}") return {"combined_results": []} def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: - print("--> Executing general search...") + logger.info("--> Executing general search...") base_url = "https://api.knowledge-space.org/datasets/search" params = {"q": query or "*", "per_page": min(top_k * 2, 50)} try: @@ -435,9 +443,9 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> "metadata": item, } ) - print(f" -> General search returned {len(normalized_results)} results") + logger.info(f" -> General search returned {len(normalized_results)} results") if enrich_details and normalized_results: - print( + logger.info( " -> Enriching results with detailed dataset information (parallel)..." ) # Use sync enrichment for now - we'll make the whole function async later @@ -445,14 +453,14 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> return {"combined_results": normalized_results[:top_k]} except requests.RequestException as e: - print(f" -> Error during general search: {e}") + logger.error(f" -> Error during general search: {e}") return {"combined_results": []} def _perform_search( data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10 ) -> List[dict]: - print( + logger.info( f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'" ) base_url = "https://knowledge-space.org/entity/source-data-by-entity" @@ -489,7 +497,7 @@ def _perform_search( .get("hits", {}) .get("hits", []) ) - print(f" -> Retrieved {len(hits)} raw results") + logger.info(f" -> Retrieved {len(hits)} raw results") out = [] for hit in hits: src = hit.get("_source", {}) or {} @@ -524,7 +532,7 @@ def _perform_search( return out except requests.RequestException as e: - print(f" -> Error searching {data_source_id}: {e}") + logger.error(f" -> Error searching {data_source_id}: {e}") return [] diff --git a/backend/tests/test_metadata_rerank.py b/backend/tests/test_metadata_rerank.py deleted file mode 100644 index ec4fe08..0000000 --- a/backend/tests/test_metadata_rerank.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest -from ks_search_tool import rerank_results_using_metadata - - -def test_rerank_max_bounds(): - """ - Test that the maximum possible boost is exactly +30% - (10% for Year, 15% for Citations, 5% for Trusted Source) - """ - results = [ - # Baseline dataset - { - "_score": 100.0, - "title_guess": "Old Data", - "metadata": {"year": 1990, "citations": 0, "source": "Unknown"}, - }, - # Perfect dataset that should get the max 1.30x multiplier - { - "_score": 100.0, - "title_guess": "Perfect Data", - "metadata": { - "year": 2024, - "citations": 10000, - "source": "Allen Brain Atlas", - }, - }, - ] - - ranked = rerank_results_using_metadata(results) - - # "Perfect Data" should be first due to boost - assert ranked[0]["title_guess"] == "Perfect Data" - - # Baseline should remain exactly 100.0 (no multiplier via min scaling) - assert ranked[1]["_score"] == 100.0 - - # Perfect Data should be exactly 130.0 (1.30x multiplier) - assert ranked[0]["_score"] == 130.0 - assert ranked[0]["_rerank_multiplier"] == 1.30 - - -def test_rerank_log_normalization(): - """ - Test that 10k citations doesn't astronomically outscore 10 citations - thanks to log normalization. - """ - results = [ - {"_score": 100.0, "title_guess": "Zero Cits", "metadata": {"citations": 0}}, - {"_score": 100.0, "title_guess": "Ten Cits", "metadata": {"citations": 10}}, - { - "_score": 100.0, - "title_guess": "Ten Thousand Cits", - "metadata": {"citations": 10000}, - }, - ] - - ranked = rerank_results_using_metadata(results) - - # Highest should still be first - assert ranked[0]["title_guess"] == "Ten Thousand Cits" - - multiplier_high = ranked[0]["_rerank_multiplier"] - multiplier_mid = ranked[1]["_rerank_multiplier"] - multiplier_low = ranked[2]["_rerank_multiplier"] - - # Verify the bounded maximum is respected (max +15% for citations) - assert multiplier_high == 1.15 - assert multiplier_low == 1.00 - - # 10 citations should give a meaningful logarithmic boost (log10(11) / log10(10001)) * 0.15 - # Let's just assert it is meaningfully greater than 1.0 but less than 1.15 - assert 1.0 < multiplier_mid < 1.15 - - -def test_rerank_empty_metadata_handling(): - """ - Test that datasets missing metadata fields do not break the calculation. - """ - results = [ - {"_score": 10.0, "title_guess": "No Meta1"}, - {"_score": 10.0, "title_guess": "No Meta2", "metadata": {}}, - { - "_score": 10.0, - "title_guess": "Garbage Meta", - "metadata": {"year": "unknown", "citations": None}, - }, - ] - - ranked = rerank_results_using_metadata(results) - - # All should retain their base score of 10.0 - for r in ranked: - assert r["_score"] == 10.0 - assert r["_rerank_multiplier"] == 1.0