diff --git a/backend/ks_search_tool.py b/backend/ks_search_tool.py index 3004a02..c0bf8fd 100644 --- a/backend/ks_search_tool.py +++ b/backend/ks_search_tool.py @@ -8,12 +8,157 @@ import re from urllib.parse import urlparse from difflib import SequenceMatcher +import math + +# --- Query Expansion for Neuroscience Terms --- +QUERY_SYNONYMS = { + "mouse brain": [ + "Mus musculus", + "somatosensory cortex", + "cortex", + "hippocampus", + ], + "memory": ["hippocampus", "synaptic plasticity"], + "hippocampus": ["CA1", "CA3", "dentate gyrus"], + "eeg": ["electroencephalography"], + "fmri": ["functional magnetic resonance imaging", "BOLD"], +} + + +def expand_query(query: str) -> str: + if not query: + return "" + query_lower = query.lower() + expanded = [query_lower] + added_terms = set(expanded) + + for phrase, synonyms in QUERY_SYNONYMS.items(): + if phrase in query_lower: + for syn in synonyms: + if syn not in added_terms: + expanded.append(syn) + added_terms.add(syn) + + for word in query_lower.split(): + if word in QUERY_SYNONYMS: + for syn in QUERY_SYNONYMS[word]: + if syn not in added_terms: + expanded.append(syn) + added_terms.add(syn) + + return " OR ".join([f'"{t}"' if " " in t else t for t in expanded]) + + +# --- Metadata Normalized Reranking --- +def rerank_results_using_metadata(results: List[dict]) -> List[dict]: + """ + Safely boosts semantic/keyword search scores using bounded metadata signals. + Employs min-max scaling and log-normalization so high citations/years + do not overpower the core relevance score. + """ + if not results: + return [] + + # 1. Gather all raw metadata values across the batch to find min/max + years = [] + citations_list = [] + + for r in results: + meta = r.get("metadata", {}) or {} + + # Extract Year safely handling falsy 0s + y = meta.get("publication_year") + if y is None: + y = meta.get("year") + + try: + if y is not None: + years.append(float(y)) + except (ValueError, TypeError): + pass + + # Extract Citations safely handling falsy 0s + c = meta.get("citations") + if c is None: + c = meta.get("citation_count") + + try: + if c is not None: + citations_list.append(float(c)) + except (ValueError, TypeError): + pass + + # Find the bounds for scaling + min_year = min(years) if years else 0 + max_year = max(years) if years else 0 + + min_cits_log = math.log10(min(citations_list) + 1) if citations_list else 0 + max_cits_log = math.log10(max(citations_list) + 1) if citations_list else 0 + + # 2. Score and Boost each item + trusted_sources = {"allen brain atlas", "gensat", "ebrains"} + + def calculate_boost(r: dict) -> float: + meta = r.get("metadata", {}) or {} + + # -- Year Boost (Max +10%) + year_boost = 0.0 + y = meta.get("publication_year") + if y is None: + y = meta.get("year") + + if y is not None and max_year > min_year: + try: + # 0.0 to 1.0 based on how new it is in this batch + y_norm = (float(y) - min_year) / (max_year - min_year) + year_boost = y_norm * 0.10 + except (ValueError, TypeError): + pass + + # -- Citations Boost (Max +15%) + cit_boost = 0.0 + c = meta.get("citations") + if c is None: + c = meta.get("citation_count") + + if c is not None and max_cits_log > min_cits_log: + try: + # Log normalize out extreme outliers (e.g. 10k citations) + c_log = math.log10(float(c) + 1) + c_norm = (c_log - min_cits_log) / (max_cits_log - min_cits_log) + cit_boost = c_norm * 0.15 + except (ValueError, TypeError): + pass + + # -- Trusted Source Boost (Max +5%) + source_boost = 0.0 + s1 = r.get("datasource_name") + s2 = meta.get("source") + # Ensure we fall back if s1 is None + source_name = str( + s1 if s1 is not None else (s2 if s2 is not None else "") + ).lower() + if any(ts in source_name for ts in trusted_sources): + source_boost = 0.05 + + # Cumulative multiplier (e.g. 1.0 to 1.30) + return 1.0 + year_boost + cit_boost + source_boost + + for item in results: + base_score = float(item.get("_score", 1.0)) + multiplier = calculate_boost(item) + item["_score"] = base_score * multiplier + item["_rerank_multiplier"] = multiplier + + # Re-sort descending based on the new multiplied scores + return sorted(results, key=lambda x: float(x.get("_score", 0)), reverse=True) def tool(args_schema): def decorator(func): func.args_schema = args_schema return func + return decorator @@ -56,7 +201,9 @@ def fuzzy_match(query: str, target: str, threshold: float = 0.8) -> bool: return similarity >= threshold -def find_best_matches(query: str, candidates: List[str], threshold: float = 0.8, max_matches: int = 5) -> List[str]: +def find_best_matches( + query: str, candidates: List[str], threshold: float = 0.8, max_matches: int = 5 +) -> List[str]: matches = [] for candidate in candidates: if fuzzy_match(query, candidate, threshold): @@ -66,7 +213,9 @@ def find_best_matches(query: str, candidates: List[str], threshold: float = 0.8, return [match[0] for match in matches[:max_matches]] -def search_across_all_fields(query: str, all_configs: dict, threshold: float = 0.8) -> List[dict]: +def search_across_all_fields( + query: str, all_configs: dict, threshold: float = 0.8 +) -> List[dict]: """ Keyword search across all available field value lists in all datasources (fuzzy). """ @@ -86,7 +235,9 @@ def search_across_all_fields(query: str, all_configs: dict, threshold: float = 0 ) results.extend(search_results) except Exception as e: - print(f"Error searching {datasource_id} with field {field_name}: {e}") + print( + f"Error searching {datasource_id} with field {field_name}: {e}" + ) continue return results @@ -102,7 +253,7 @@ def global_fuzzy_keyword_search(keywords: Iterable[str], top_k: int = 20) -> Lis all_configs = json.load(fh) out: List[dict] = [] seen = set() - for kw in (keywords or []): + for kw in keywords or []: if not kw: continue results = search_across_all_fields(kw, all_configs, threshold=0.8) @@ -162,7 +313,9 @@ def extract_datasource_info_from_link(link: str) -> tuple: return None, None -async def fetch_dataset_details_async(session, datasource_id: str, dataset_id: str) -> dict: +async def fetch_dataset_details_async( + session, datasource_id: str, dataset_id: str +) -> dict: if not datasource_id or not dataset_id: return {} try: @@ -174,6 +327,7 @@ async def fetch_dataset_details_async(session, datasource_id: str, dataset_id: s print(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") return {} + def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: if not datasource_id or not dataset_id: return {} @@ -187,74 +341,92 @@ def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: return {} -async def enrich_with_dataset_details_async(results: List[dict], top_k: int = 10) -> List[dict]: +async def enrich_with_dataset_details_async( + results: List[dict], top_k: int = 10 +) -> List[dict]: """ Parallel enrichment - fetches dataset details for multiple datasets simultaneously. Instead of: fetch dataset1 -> wait -> fetch dataset2 -> wait -> fetch dataset3 We do: fetch dataset1, dataset2, dataset3 ALL AT ONCE -> wait for all to complete This can reduce enrichment time from 3+ seconds to <1 second for 10 datasets. """ - + async def enrich_single_result(session, result, index): try: # Extract datasource info - link = result.get("primary_link", "") or result.get("metadata", {}).get("url", "") + link = result.get("primary_link", "") or result.get("metadata", {}).get( + "url", "" + ) datasource_id, dataset_id = extract_datasource_info_from_link(link) - + if not datasource_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - source_info = metadata.get("source", "") or metadata.get("datasource", "") + source_info = metadata.get("source", "") or metadata.get( + "datasource", "" + ) if source_info: for name, ds_id in DATASOURCE_NAME_TO_ID.items(): if name.lower() in str(source_info).lower(): datasource_id = ds_id break - + if datasource_id and not dataset_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - dataset_id = metadata.get("id", "") or metadata.get("dataset_id", "") or result.get("_id", "") - + dataset_id = ( + metadata.get("id", "") + or metadata.get("dataset_id", "") + or result.get("_id", "") + ) + # Fetch details if we have both IDs if datasource_id and dataset_id: - print(f" -> Parallel fetching details for {datasource_id}/{dataset_id}") - details = await fetch_dataset_details_async(session, datasource_id, dataset_id) + print( + f" -> Parallel fetching details for {datasource_id}/{dataset_id}" + ) + details = await fetch_dataset_details_async( + session, datasource_id, dataset_id + ) if details: result["detailed_info"] = details result["datasource_id"] = datasource_id - result["datasource_name"] = DATASOURCE_ID_TO_NAME.get(datasource_id, datasource_id) + result["datasource_name"] = DATASOURCE_ID_TO_NAME.get( + datasource_id, datasource_id + ) if "metadata" not in result: result["metadata"] = {} result["metadata"].update(details) - + return result, index - + except Exception as e: print(f" -> Error enriching result {index}: {e}") return result, index # Return original result if enrichment fails - + # Create HTTP session with connection pooling connector = aiohttp.TCPConnector( limit=20, # Total connection pool limit_per_host=10, # Max 10 connections per host - keepalive_timeout=30 + keepalive_timeout=30, ) - + async with aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=8, connect=2) + connector=connector, timeout=aiohttp.ClientTimeout(total=8, connect=2) ) as session: # Create tasks for ALL results at once - this is the "parallel" part - tasks = [enrich_single_result(session, result, i) for i, result in enumerate(results[:top_k])] - + tasks = [ + enrich_single_result(session, result, i) + for i, result in enumerate(results[:top_k]) + ] + print(f" -> Starting {len(tasks)} parallel enrichment tasks") start_time = asyncio.get_event_loop().time() - + # Execute ALL tasks simultaneously completed_results = await asyncio.gather(*tasks, return_exceptions=True) - + end_time = asyncio.get_event_loop().time() print(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s") - + # Reconstruct results in original order enriched_results = [None] * len(results[:top_k]) for item in completed_results: @@ -263,14 +435,17 @@ async def enrich_single_result(session, result, index): continue result, index = item enriched_results[index] = result - + # Filter out None values and return return [r for r in enriched_results if r is not None] + def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[dict]: enriched_results = [] for i, result in enumerate(results[:top_k]): - link = result.get("primary_link", "") or result.get("metadata", {}).get("url", "") + link = result.get("primary_link", "") or result.get("metadata", {}).get( + "url", "" + ) datasource_id, dataset_id = extract_datasource_info_from_link(link) if not datasource_id: metadata = result.get("metadata", {}) or result.get("_source", {}) @@ -282,13 +457,19 @@ def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[di break if datasource_id and not dataset_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - dataset_id = metadata.get("id", "") or metadata.get("dataset_id", "") or result.get("_id", "") + dataset_id = ( + metadata.get("id", "") + or metadata.get("dataset_id", "") + or result.get("_id", "") + ) if datasource_id and dataset_id: details = fetch_dataset_details(datasource_id, dataset_id) if details: result["detailed_info"] = details result["datasource_id"] = datasource_id - result["datasource_name"] = DATASOURCE_ID_TO_NAME.get(datasource_id, datasource_id) + result["datasource_name"] = DATASOURCE_ID_TO_NAME.get( + datasource_id, datasource_id + ) if "metadata" not in result: result["metadata"] = {} result["metadata"].update(details) @@ -296,22 +477,35 @@ def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[di return enriched_results -async def general_search_async(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: +async def general_search_async( + query: str, top_k: int = 10, enrich_details: bool = True +) -> dict: """Async version of general search with parallel enrichment""" print("--> Executing async general search...") + query = expand_query(query) if query else "*" base_url = "https://api.knowledge-space.org/datasets/search" - params = {"q": query or "*", "per_page": min(top_k * 2, 50)} + params = {"q": query, "per_page": min(top_k * 2, 50)} try: async with aiohttp.ClientSession() as session: async with session.get(base_url, params=params, timeout=15) as resp: resp.raise_for_status() data = await resp.json() - + results_list = data.get("results", []) normalized_results = [] for i, item in enumerate(results_list): - title = item.get("title") or item.get("name") or item.get("dc.title") or "Dataset" - description = item.get("description") or item.get("abstract") or item.get("summary") or "" + title = ( + item.get("title") + or item.get("name") + or item.get("dc.title") + or "Dataset" + ) + description = ( + item.get("description") + or item.get("abstract") + or item.get("summary") + or "" + ) url = ( item.get("url") or item.get("link") @@ -333,16 +527,21 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool print(f" -> Async general search returned {len(normalized_results)} results") if enrich_details and normalized_results: print(" -> Using parallel async enrichment...") - normalized_results = await enrich_with_dataset_details_async(normalized_results, top_k) + normalized_results = await enrich_with_dataset_details_async( + normalized_results, top_k + ) + normalized_results = rerank_results_using_metadata(normalized_results) return {"combined_results": normalized_results[:top_k]} except Exception as e: print(f" -> Error during async general search: {e}") return {"combined_results": []} + def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: print("--> Executing general search...") + query = expand_query(query) if query else "*" base_url = "https://api.knowledge-space.org/datasets/search" - params = {"q": query or "*", "per_page": min(top_k * 2, 50)} + params = {"q": query, "per_page": min(top_k * 2, 50)} try: resp = requests.get(base_url, params=params, timeout=15) resp.raise_for_status() @@ -350,8 +549,18 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> results_list = data.get("results", []) normalized_results = [] for i, item in enumerate(results_list): - title = item.get("title") or item.get("name") or item.get("dc.title") or "Dataset" - description = item.get("description") or item.get("abstract") or item.get("summary") or "" + title = ( + item.get("title") + or item.get("name") + or item.get("dc.title") + or "Dataset" + ) + description = ( + item.get("description") + or item.get("abstract") + or item.get("summary") + or "" + ) url = ( item.get("url") or item.get("link") @@ -373,17 +582,24 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> ) print(f" -> General search returned {len(normalized_results)} results") if enrich_details and normalized_results: - print(" -> Enriching results with detailed dataset information (parallel)...") + print( + " -> Enriching results with detailed dataset information (parallel)..." + ) # Use sync enrichment for now - we'll make the whole function async later normalized_results = enrich_with_dataset_details(normalized_results, top_k) + normalized_results = rerank_results_using_metadata(normalized_results) return {"combined_results": normalized_results[:top_k]} except requests.RequestException as e: print(f" -> Error during general search: {e}") return {"combined_results": []} -def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10) -> List[dict]: - print(f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'") +def _perform_search( + data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10 +) -> List[dict]: + print( + f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'" + ) base_url = "https://knowledge-space.org/entity/source-data-by-entity" valid_filter_map = all_configs.get(data_source_id, {}).get("available_filters", {}) exact_match_filters = [] @@ -413,13 +629,24 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: resp = requests.get(base_url, params=params, timeout=timeout) resp.raise_for_status() data = resp.json() - hits = (data[0] if isinstance(data, list) and data else data).get("hits", {}).get("hits", []) + hits = ( + (data[0] if isinstance(data, list) and data else data) + .get("hits", {}) + .get("hits", []) + ) print(f" -> Retrieved {len(hits)} raw results") out = [] for hit in hits: src = hit.get("_source", {}) or {} - title = src.get("title") or src.get("name") or src.get("dc.title") or "Dataset" - desc = src.get("description") or src.get("abstract") or src.get("summary") or "" + title = ( + src.get("title") or src.get("name") or src.get("dc.title") or "Dataset" + ) + desc = ( + src.get("description") + or src.get("abstract") + or src.get("summary") + or "" + ) link = ( src.get("url") or src.get("link") @@ -439,6 +666,7 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: "metadata": src, } ) + out = rerank_results_using_metadata(out) return out except requests.RequestException as e: print(f" -> Error searching {data_source_id}: {e}") @@ -452,13 +680,15 @@ def smart_knowledge_search( data_source: Optional[str] = None, top_k: int = 10, ) -> dict: - q = query or "*" + q = expand_query(query) if query else "*" if filters: config_path = "datasources_config.json" if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as fh: all_configs = json.load(fh) - target_id = DATASOURCE_NAME_TO_ID.get(data_source) or (data_source if data_source in all_configs else None) + target_id = DATASOURCE_NAME_TO_ID.get(data_source) or ( + data_source if data_source in all_configs else None + ) if target_id: results = _perform_search(target_id, q, dict(filters), all_configs) return {"combined_results": results[:top_k]} diff --git a/backend/tests/test_metadata_rerank.py b/backend/tests/test_metadata_rerank.py new file mode 100644 index 0000000..e29d087 --- /dev/null +++ b/backend/tests/test_metadata_rerank.py @@ -0,0 +1,94 @@ +import pytest +from ks_search_tool import rerank_results_using_metadata + + +def test_rerank_max_bounds(): + """ + Test that the maximum possible boost is exactly +30% + (10% for Year, 15% for Citations, 5% for Trusted Source) + """ + results = [ + # Baseline dataset + { + "_score": 100.0, + "title_guess": "Old Data", + "metadata": {"year": 1990, "citations": 0, "source": "Unknown"}, + }, + # Perfect dataset that should get the max 1.30x multiplier + { + "_score": 100.0, + "title_guess": "Perfect Data", + "metadata": { + "year": 2024, + "citations": 10000, + "source": "Allen Brain Atlas", + }, + }, + ] + + ranked = rerank_results_using_metadata(results) + + # "Perfect Data" should be first due to boost + assert ranked[0]["title_guess"] == "Perfect Data" + + # Baseline should remain exactly 100.0 (no multiplier via min scaling) + assert ranked[1]["_score"] == 100.0 + + # Perfect Data should be exactly 130.0 (1.30x multiplier) + assert ranked[0]["_score"] == pytest.approx(130.0) + assert ranked[0]["_rerank_multiplier"] == pytest.approx(1.30) + + +def test_rerank_log_normalization(): + """ + Test that 10k citations doesn't astronomically outscore 10 citations + thanks to log normalization. + """ + results = [ + {"_score": 100.0, "title_guess": "Zero Cits", "metadata": {"citations": 0}}, + {"_score": 100.0, "title_guess": "Ten Cits", "metadata": {"citations": 10}}, + { + "_score": 100.0, + "title_guess": "Ten Thousand Cits", + "metadata": {"citations": 10000}, + }, + ] + + ranked = rerank_results_using_metadata(results) + + # Highest should still be first + assert ranked[0]["title_guess"] == "Ten Thousand Cits" + + multiplier_high = ranked[0]["_rerank_multiplier"] + multiplier_mid = ranked[1]["_rerank_multiplier"] + multiplier_low = ranked[2]["_rerank_multiplier"] + + # Verify the bounded maximum is respected (max +15% for citations) + assert multiplier_high == pytest.approx(1.15) + assert multiplier_low == 1.00 + + # 10 citations should give a meaningful logarithmic boost (log10(11) / log10(10001)) * 0.15 + # Let's just assert it is meaningfully greater than 1.0 but less than 1.15 + assert 1.0 < multiplier_mid < 1.15 + + +def test_rerank_empty_metadata_handling(): + """ + Test that datasets missing metadata fields do not break the calculation. + """ + results = [ + {"_score": 10.0, "title_guess": "No Meta1"}, + {"_score": 10.0, "title_guess": "No Meta2", "metadata": {}}, + { + "_score": 10.0, + "title_guess": "Garbage Meta", + "metadata": {"year": "unknown", "citations": None}, + }, + ] + + ranked = rerank_results_using_metadata(results) + + # All should retain their base score of 10.0 + for r in ranked: + assert r["_score"] == 10.0 + assert r["_rerank_multiplier"] == 1.0