diff --git a/backend/ks_search_tool.py b/backend/ks_search_tool.py index 3004a02..242083e 100644 --- a/backend/ks_search_tool.py +++ b/backend/ks_search_tool.py @@ -1,6 +1,14 @@ # ks_search_tool.py import os import json +import logging + +logger = logging.getLogger("ks_search_tool") +logger.setLevel(logging.INFO) +if not logger.handlers: + _h = logging.StreamHandler() + _h.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) + logger.addHandler(_h) import requests import asyncio import aiohttp @@ -10,10 +18,12 @@ from difflib import SequenceMatcher + def tool(args_schema): def decorator(func): func.args_schema = args_schema return func + return decorator @@ -56,7 +66,9 @@ def fuzzy_match(query: str, target: str, threshold: float = 0.8) -> bool: return similarity >= threshold -def find_best_matches(query: str, candidates: List[str], threshold: float = 0.8, max_matches: int = 5) -> List[str]: +def find_best_matches( + query: str, candidates: List[str], threshold: float = 0.8, max_matches: int = 5 +) -> List[str]: matches = [] for candidate in candidates: if fuzzy_match(query, candidate, threshold): @@ -66,7 +78,9 @@ def find_best_matches(query: str, candidates: List[str], threshold: float = 0.8, return [match[0] for match in matches[:max_matches]] -def search_across_all_fields(query: str, all_configs: dict, threshold: float = 0.8) -> List[dict]: +def search_across_all_fields( + query: str, all_configs: dict, threshold: float = 0.8 +) -> List[dict]: """ Keyword search across all available field value lists in all datasources (fuzzy). """ @@ -86,7 +100,9 @@ def search_across_all_fields(query: str, all_configs: dict, threshold: float = 0 ) results.extend(search_results) except Exception as e: - print(f"Error searching {datasource_id} with field {field_name}: {e}") + logger.info( + f"Error searching {datasource_id} with field {field_name}: {e}" + ) continue return results @@ -102,7 +118,7 @@ def global_fuzzy_keyword_search(keywords: Iterable[str], top_k: int = 20) -> Lis all_configs = json.load(fh) out: List[dict] = [] seen = set() - for kw in (keywords or []): + for kw in keywords or []: if not kw: continue results = search_across_all_fields(kw, all_configs, threshold=0.8) @@ -162,7 +178,9 @@ def extract_datasource_info_from_link(link: str) -> tuple: return None, None -async def fetch_dataset_details_async(session, datasource_id: str, dataset_id: str) -> dict: +async def fetch_dataset_details_async( + session, datasource_id: str, dataset_id: str +) -> dict: if not datasource_id or not dataset_id: return {} try: @@ -171,9 +189,10 @@ async def fetch_dataset_details_async(session, datasource_id: str, dataset_id: s resp.raise_for_status() return await resp.json() except Exception as e: - print(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") + logger.error(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") return {} + def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: if not datasource_id or not dataset_id: return {} @@ -183,94 +202,115 @@ def fetch_dataset_details(datasource_id: str, dataset_id: str) -> dict: resp.raise_for_status() return resp.json() except Exception as e: - print(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") + logger.error(f" -> Error fetching details for {datasource_id}/{dataset_id}: {e}") return {} -async def enrich_with_dataset_details_async(results: List[dict], top_k: int = 10) -> List[dict]: +async def enrich_with_dataset_details_async( + results: List[dict], top_k: int = 10 +) -> List[dict]: """ Parallel enrichment - fetches dataset details for multiple datasets simultaneously. Instead of: fetch dataset1 -> wait -> fetch dataset2 -> wait -> fetch dataset3 We do: fetch dataset1, dataset2, dataset3 ALL AT ONCE -> wait for all to complete This can reduce enrichment time from 3+ seconds to <1 second for 10 datasets. """ - + async def enrich_single_result(session, result, index): try: # Extract datasource info - link = result.get("primary_link", "") or result.get("metadata", {}).get("url", "") + link = result.get("primary_link", "") or result.get("metadata", {}).get( + "url", "" + ) datasource_id, dataset_id = extract_datasource_info_from_link(link) - + if not datasource_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - source_info = metadata.get("source", "") or metadata.get("datasource", "") + source_info = metadata.get("source", "") or metadata.get( + "datasource", "" + ) if source_info: for name, ds_id in DATASOURCE_NAME_TO_ID.items(): if name.lower() in str(source_info).lower(): datasource_id = ds_id break - + if datasource_id and not dataset_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - dataset_id = metadata.get("id", "") or metadata.get("dataset_id", "") or result.get("_id", "") - + dataset_id = ( + metadata.get("id", "") + or metadata.get("dataset_id", "") + or result.get("_id", "") + ) + # Fetch details if we have both IDs if datasource_id and dataset_id: - print(f" -> Parallel fetching details for {datasource_id}/{dataset_id}") - details = await fetch_dataset_details_async(session, datasource_id, dataset_id) + logger.info( + f" -> Parallel fetching details for {datasource_id}/{dataset_id}" + ) + details = await fetch_dataset_details_async( + session, datasource_id, dataset_id + ) if details: result["detailed_info"] = details result["datasource_id"] = datasource_id - result["datasource_name"] = DATASOURCE_ID_TO_NAME.get(datasource_id, datasource_id) + result["datasource_name"] = DATASOURCE_ID_TO_NAME.get( + datasource_id, datasource_id + ) if "metadata" not in result: result["metadata"] = {} result["metadata"].update(details) - + return result, index - + except Exception as e: - print(f" -> Error enriching result {index}: {e}") + logger.error(f" -> Error enriching result {index}: {e}") return result, index # Return original result if enrichment fails - + # Create HTTP session with connection pooling connector = aiohttp.TCPConnector( limit=20, # Total connection pool limit_per_host=10, # Max 10 connections per host - keepalive_timeout=30 + keepalive_timeout=30, ) - + async with aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=8, connect=2) + connector=connector, timeout=aiohttp.ClientTimeout(total=8, connect=2) ) as session: # Create tasks for ALL results at once - this is the "parallel" part - tasks = [enrich_single_result(session, result, i) for i, result in enumerate(results[:top_k])] - - print(f" -> Starting {len(tasks)} parallel enrichment tasks") + tasks = [ + enrich_single_result(session, result, i) + for i, result in enumerate(results[:top_k]) + ] + + logger.info(f" -> Starting {len(tasks)} parallel enrichment tasks") start_time = asyncio.get_event_loop().time() - + # Execute ALL tasks simultaneously completed_results = await asyncio.gather(*tasks, return_exceptions=True) - + end_time = asyncio.get_event_loop().time() - print(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s") - + logger.info(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s") + # Reconstruct results in original order enriched_results = [None] * len(results[:top_k]) for item in completed_results: if isinstance(item, Exception): - print(f" -> Task failed: {item}") + logger.info(f" -> Task failed: {item}") continue result, index = item enriched_results[index] = result - + # Filter out None values and return return [r for r in enriched_results if r is not None] + def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[dict]: enriched_results = [] for i, result in enumerate(results[:top_k]): - link = result.get("primary_link", "") or result.get("metadata", {}).get("url", "") + link = result.get("primary_link", "") or result.get("metadata", {}).get( + "url", "" + ) datasource_id, dataset_id = extract_datasource_info_from_link(link) if not datasource_id: metadata = result.get("metadata", {}) or result.get("_source", {}) @@ -282,13 +322,19 @@ def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[di break if datasource_id and not dataset_id: metadata = result.get("metadata", {}) or result.get("_source", {}) - dataset_id = metadata.get("id", "") or metadata.get("dataset_id", "") or result.get("_id", "") + dataset_id = ( + metadata.get("id", "") + or metadata.get("dataset_id", "") + or result.get("_id", "") + ) if datasource_id and dataset_id: details = fetch_dataset_details(datasource_id, dataset_id) if details: result["detailed_info"] = details result["datasource_id"] = datasource_id - result["datasource_name"] = DATASOURCE_ID_TO_NAME.get(datasource_id, datasource_id) + result["datasource_name"] = DATASOURCE_ID_TO_NAME.get( + datasource_id, datasource_id + ) if "metadata" not in result: result["metadata"] = {} result["metadata"].update(details) @@ -296,9 +342,11 @@ def enrich_with_dataset_details(results: List[dict], top_k: int = 10) -> List[di return enriched_results -async def general_search_async(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: +async def general_search_async( + query: str, top_k: int = 10, enrich_details: bool = True +) -> dict: """Async version of general search with parallel enrichment""" - print("--> Executing async general search...") + logger.info("--> Executing async general search...") base_url = "https://api.knowledge-space.org/datasets/search" params = {"q": query or "*", "per_page": min(top_k * 2, 50)} try: @@ -306,12 +354,22 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool async with session.get(base_url, params=params, timeout=15) as resp: resp.raise_for_status() data = await resp.json() - + results_list = data.get("results", []) normalized_results = [] for i, item in enumerate(results_list): - title = item.get("title") or item.get("name") or item.get("dc.title") or "Dataset" - description = item.get("description") or item.get("abstract") or item.get("summary") or "" + title = ( + item.get("title") + or item.get("name") + or item.get("dc.title") + or "Dataset" + ) + description = ( + item.get("description") + or item.get("abstract") + or item.get("summary") + or "" + ) url = ( item.get("url") or item.get("link") @@ -330,17 +388,21 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool "metadata": item, } ) - print(f" -> Async general search returned {len(normalized_results)} results") + logger.info(f" -> Async general search returned {len(normalized_results)} results") if enrich_details and normalized_results: - print(" -> Using parallel async enrichment...") - normalized_results = await enrich_with_dataset_details_async(normalized_results, top_k) + logger.info(" -> Using parallel async enrichment...") + normalized_results = await enrich_with_dataset_details_async( + normalized_results, top_k + ) + return {"combined_results": normalized_results[:top_k]} except Exception as e: - print(f" -> Error during async general search: {e}") + logger.error(f" -> Error during async general search: {e}") return {"combined_results": []} + def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> dict: - print("--> Executing general search...") + logger.info("--> Executing general search...") base_url = "https://api.knowledge-space.org/datasets/search" params = {"q": query or "*", "per_page": min(top_k * 2, 50)} try: @@ -350,8 +412,18 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> results_list = data.get("results", []) normalized_results = [] for i, item in enumerate(results_list): - title = item.get("title") or item.get("name") or item.get("dc.title") or "Dataset" - description = item.get("description") or item.get("abstract") or item.get("summary") or "" + title = ( + item.get("title") + or item.get("name") + or item.get("dc.title") + or "Dataset" + ) + description = ( + item.get("description") + or item.get("abstract") + or item.get("summary") + or "" + ) url = ( item.get("url") or item.get("link") @@ -371,19 +443,26 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> "metadata": item, } ) - print(f" -> General search returned {len(normalized_results)} results") + logger.info(f" -> General search returned {len(normalized_results)} results") if enrich_details and normalized_results: - print(" -> Enriching results with detailed dataset information (parallel)...") + logger.info( + " -> Enriching results with detailed dataset information (parallel)..." + ) # Use sync enrichment for now - we'll make the whole function async later normalized_results = enrich_with_dataset_details(normalized_results, top_k) + return {"combined_results": normalized_results[:top_k]} except requests.RequestException as e: - print(f" -> Error during general search: {e}") + logger.error(f" -> Error during general search: {e}") return {"combined_results": []} -def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10) -> List[dict]: - print(f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'") +def _perform_search( + data_source_id: str, query: str, filters: dict, all_configs: dict, timeout: int = 10 +) -> List[dict]: + logger.info( + f"--> Searching source '{data_source_id}' with query: '{(query or '*')[:50]}...'" + ) base_url = "https://knowledge-space.org/entity/source-data-by-entity" valid_filter_map = all_configs.get(data_source_id, {}).get("available_filters", {}) exact_match_filters = [] @@ -413,13 +492,24 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: resp = requests.get(base_url, params=params, timeout=timeout) resp.raise_for_status() data = resp.json() - hits = (data[0] if isinstance(data, list) and data else data).get("hits", {}).get("hits", []) - print(f" -> Retrieved {len(hits)} raw results") + hits = ( + (data[0] if isinstance(data, list) and data else data) + .get("hits", {}) + .get("hits", []) + ) + logger.info(f" -> Retrieved {len(hits)} raw results") out = [] for hit in hits: src = hit.get("_source", {}) or {} - title = src.get("title") or src.get("name") or src.get("dc.title") or "Dataset" - desc = src.get("description") or src.get("abstract") or src.get("summary") or "" + title = ( + src.get("title") or src.get("name") or src.get("dc.title") or "Dataset" + ) + desc = ( + src.get("description") + or src.get("abstract") + or src.get("summary") + or "" + ) link = ( src.get("url") or src.get("link") @@ -439,9 +529,10 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: "metadata": src, } ) + return out except requests.RequestException as e: - print(f" -> Error searching {data_source_id}: {e}") + logger.error(f" -> Error searching {data_source_id}: {e}") return [] @@ -458,7 +549,9 @@ def smart_knowledge_search( if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as fh: all_configs = json.load(fh) - target_id = DATASOURCE_NAME_TO_ID.get(data_source) or (data_source if data_source in all_configs else None) + target_id = DATASOURCE_NAME_TO_ID.get(data_source) or ( + data_source if data_source in all_configs else None + ) if target_id: results = _perform_search(target_id, q, dict(filters), all_configs) return {"combined_results": results[:top_k]} diff --git a/backend/retrieval.py b/backend/retrieval.py index 31a0452..a8702ab 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -62,7 +62,7 @@ class VertexRetriever(BaseRetriever): - EMBED_MODEL_NAME default: nomic-ai/nomic-embed-text-v1.5 - BQ_DATASET_ID default: ks_metadata - BQ_TABLE_ID default: docstore - - BQ_LOCATION default: US + - BQ_LOCATION default: EU - EMBED_MAX_TOKENS default: 1024 - QUERY_CHAR_LIMIT default: 8000 """ @@ -164,7 +164,7 @@ def _bq_fetch(self, ids: List[str]) -> Dict[str, Dict[str, Any]]: cfg = bigquery.QueryJobConfig( query_parameters=[bigquery.ArrayQueryParameter("ids", "STRING", ids)] ) - rows = self.bq.query(sql, job_config=cfg, location=self.bq_location).result() + rows = self.bq.query(sql, job_config=cfg, location=self.bq_location).result(timeout=10) out: Dict[str, Dict[str, Any]] = {} for r in rows: md = r.metadata_filters @@ -190,7 +190,7 @@ def search( if not self.is_enabled or not query: return [] - qtext = query if (context or {}).get("raw") else query + qtext = query try: vec = self._embed(qtext) @@ -231,10 +231,15 @@ def search( or "" ) try: + # Vertex AI returns L2 distance (lower is better), so we negate it for descending similarity sort similarity = -float(dist) if dist is not None else 0.0 except Exception: similarity = 0.0 + other_links = md.get("other_links", []) + if not isinstance(other_links, list): + other_links = [] + items.append( RetrievedItem( id=dp_id, @@ -242,7 +247,7 @@ def search( content=str(content), metadata=md, primary_link=link, - other_links=[], + other_links=other_links, similarity=similarity, ) )