Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions backend/ks_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,22 @@
import aiohttp
from typing import Dict, Optional, Set, Union, List, Any, Iterable
import re
import logging
from urllib.parse import urlparse
from difflib import SequenceMatcher

# Module-level logger for this tool. NOTE(review): setting the level here
# overrides the host application's logging configuration — confirm this is
# intentional rather than leaving level control to the root config.
logger = logging.getLogger("ks_search_tool")
logger.setLevel(logging.INFO)


def tool(args_schema=None):
    """Lightweight stand-in for a LangChain-style ``@tool`` decorator.

    Attaches ``args_schema`` to the decorated function as the
    ``args_schema`` attribute and returns the function otherwise
    unchanged, so it can be used both as ``@tool()`` and
    ``@tool(args_schema=SomeSchema)``.

    Args:
        args_schema: Optional schema object describing the tool's
            arguments. Defaults to ``None`` so the decorator can be
            applied without arguments.

    Returns:
        A decorator that annotates and returns the target function.
    """
    def decorator(func):
        func.args_schema = args_schema
        return func

    return decorator


class BaseModel:
    """No-op placeholder class (presumably standing in for
    ``pydantic.BaseModel`` when the real dependency is unavailable —
    TODO confirm)."""


class Field:
    """Minimal stand-in for a pydantic-style ``Field`` descriptor.

    Bug fix: the original defined ``_init_`` (missing the dunder
    underscores), so Python never called it and any instantiation with
    keyword arguments — e.g. ``Field(description=...)`` — raised
    ``TypeError`` via ``object.__init__``. The constructor is now
    correctly named and stores its arguments for introspection.
    """

    def __init__(self, description="", default_factory=None):
        # Keep the values instead of discarding them so call sites can
        # read them back; storing is backward-compatible (only adds
        # attributes).
        self.description = description
        self.default_factory = default_factory


DATASOURCE_NAME_TO_ID = {
"Allen Brain Atlas Mouse Brain - Expression": "scr_002978_aba_expression",
"GENSAT": "scr_002721_gensat_geneexpression",
Expand Down Expand Up @@ -90,19 +84,17 @@ def search_across_all_fields(
for field_name, field_config in available_filters.items():
field_values = field_config.get("values", [])
matches = find_best_matches(query, field_values, threshold)
if matches:
for match in matches:
try:
search_results = _perform_search(
datasource_id,
query,
{field_name: matches[0]},
{field_name: match},
all_configs,
)
results.extend(search_results)
except Exception as e:
logger.info(
f"Error searching {datasource_id} with field {field_name}: {e}"
)
logger.error(f"Error searching {datasource_id} with field {field_name}: {e}")
continue
return results

Expand All @@ -113,6 +105,7 @@ def global_fuzzy_keyword_search(keywords: Iterable[str], top_k: int = 20) -> Lis
"""
config_path = "datasources_config.json"
if not os.path.exists(config_path):
logger.warning(f"Configuration file missing: {config_path}")
return []
with open(config_path, "r", encoding="utf-8") as fh:
all_configs = json.load(fh)
Expand Down Expand Up @@ -284,19 +277,19 @@ async def enrich_single_result(session, result, index):
]

logger.info(f" -> Starting {len(tasks)} parallel enrichment tasks")
start_time = asyncio.get_event_loop().time()
start_time = asyncio.get_running_loop().time()

# Execute ALL tasks simultaneously
completed_results = await asyncio.gather(*tasks, return_exceptions=True)

end_time = asyncio.get_event_loop().time()
end_time = asyncio.get_running_loop().time()
logger.info(f" -> Parallel enrichment completed in {end_time - start_time:.2f}s")

# Reconstruct results in original order
enriched_results = [None] * len(results[:top_k])
for item in completed_results:
if isinstance(item, Exception):
logger.info(f" -> Task failed: {item}")
logger.error(f" -> Task failed: {item}")
continue
result, index = item
enriched_results[index] = result
Expand Down Expand Up @@ -434,11 +427,11 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) ->
)
normalized_results.append(
{
"id": item.get("id", f"ks{i}"),
"_id": item.get("id", f"general_{i}"),
"_source": item,
"_score": 1.0,
"title_guess": title,
"content": description,
"title": title,
"description": description[:500],
"primary_link": url,
"metadata": item,
}
Expand Down Expand Up @@ -523,8 +516,8 @@ def _perform_search(
"_id": hit.get("_id"),
"_source": src,
"_score": hit.get("_score", 1.0),
"title_guess": title,
"content": desc,
"title": title,
"description": desc[:500],
"primary_link": link,
"metadata": src,
}
Expand All @@ -536,7 +529,7 @@ def _perform_search(
return []


@tool(args_schema=BaseModel)
@tool()
def smart_knowledge_search(
query: Optional[str] = None,
filters: Optional[Union[Dict, Set]] = None,
Expand Down