Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,6 @@ cython_debug/

#Embedded data files
data_processing/embeddings.jsonl

# Local planning / PR-unrelated (do not push)
.plan/
2 changes: 1 addition & 1 deletion backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ RUN pip install uv

COPY pyproject.toml ./
COPY .env ./
COPY backend/service-account.json ./service-account.json
COPY backend/service-account.json.example ./service-account.json


RUN UV_HTTP_TIMEOUT=300 uv sync
Expand Down
42 changes: 42 additions & 0 deletions backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
from retrieval import get_retriever
from ttfr_estimator import estimate_ttfr

# LLM (Gemini) client setup
try:
Expand Down Expand Up @@ -442,6 +443,22 @@ def fuse_results(state: AgentState) -> AgentState:
else:
combined[doc_id] = {**res, "final_score": res.get("_score", 0) * 0.4}
all_sorted = sorted(combined.values(), key=lambda x: x.get("final_score", 0), reverse=True)
for result in all_sorted:
try:
est = estimate_ttfr(
datasource_id=result.get("datasource_id"),
metadata=result.get("metadata") or result.get("detailed_info") or {},
content=result.get("content") or result.get("description") or "",
)
result["ttfr_estimate"] = {
"summary": str(est.summary),
"min_days": est.summary.min_days,
"max_days": est.summary.max_days,
"assumptions": est.assumptions,
}
except Exception as e:
result["ttfr_estimate"] = None
print(f"TTFR estimate failed for result: {e}")
print(f"Results summary: KS={len(ks_results)}, Vector={len(vector_results)}, Combined={len(all_sorted)}")
page_size = 15
return {**state, "all_results": all_sorted, "final_results": all_sorted[:page_size]}
Expand Down Expand Up @@ -485,8 +502,12 @@ class NeuroscienceAssistant:
def __init__(self):
    """Initialise empty per-session state stores and compile the agent graph."""
    # Rolling "User: ..." / "Assistant: ..." transcript per session id
    # (trimmed to the last 20 entries by handle_chat).
    self.chat_history: Dict[str, List[str]] = {}
    # Per-session scratch state (e.g. "intents", "last_text") carried
    # between turns by handle_chat.
    self.session_memory: Dict[str, Dict[str, Any]] = {}
    # Metadata attached to the most recent reply of each session
    # (currently TTFR estimates); exposed via get_last_response_metadata().
    self._last_response_metadata: Dict[str, dict] = {}
    # Compiled LangGraph StateGraph built from the module-level node functions.
    self.graph = self._build_graph()

def get_last_response_metadata(self, session_id: str) -> dict:
    """Return extra metadata (e.g. TTFR estimates) recorded for the session's
    most recent reply, or an empty dict for an unknown session."""
    per_session = self._last_response_metadata
    return per_session.get(session_id, {})

def _build_graph(self):
workflow = StateGraph(AgentState)
workflow.add_node("prepare", extract_keywords_and_rewrite)
Expand All @@ -503,6 +524,7 @@ def _build_graph(self):
def reset_session(self, session_id: str):
    """Forget everything recorded for *session_id*: chat transcript,
    scratch memory, and last-reply metadata. Safe on unknown sessions."""
    for store in (
        self.chat_history,
        self.session_memory,
        self._last_response_metadata,
    ):
        store.pop(session_id, None)


async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> str:
Expand Down Expand Up @@ -537,6 +559,16 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) ->
"last_text": f"{prev_text}\n\n{text}"[-12000:],
})
self.session_memory[session_id] = mem
ttfr_estimates = []
for r in batch:
te = r.get("ttfr_estimate")
if te and isinstance(te, dict):
ttfr_estimates.append({
"id": r.get("id") or r.get("_id"),
"title": r.get("title_guess") or r.get("title"),
"ttfr_summary": te.get("summary"),
})
self._last_response_metadata[session_id] = {"ttfr_estimates": ttfr_estimates}
self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {text}"])
if len(self.chat_history[session_id]) > 20:
self.chat_history[session_id] = self.chat_history[session_id][-20:]
Expand Down Expand Up @@ -569,6 +601,16 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) ->
"intents": final_state.get("intents", [QueryIntent.DATA_DISCOVERY.value]),
"last_text": response_text,
}
ttfr_estimates = []
for r in final_state.get("final_results", [])[:15]:
te = r.get("ttfr_estimate")
if te and isinstance(te, dict):
ttfr_estimates.append({
"id": r.get("id") or r.get("_id"),
"title": r.get("title_guess") or r.get("title"),
"ttfr_summary": te.get("summary"),
})
self._last_response_metadata[session_id] = {"ttfr_estimates": ttfr_estimates}

self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {response_text}"])
if len(self.chat_history[session_id]) > 20:
Expand Down
30 changes: 30 additions & 0 deletions backend/demo_ttfr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import json
from ttfr_estimator import estimate_ttfr

# Demo inputs for estimate_ttfr: three known datasource ids, one
# content-only entry, and an unknown id paired with free-text content
# (exercising the content-based fallback path).
EXAMPLES = [
    {"datasource_id": "scr_005031_openneuro"},
    {"datasource_id": "scr_017612_ebrains"},
    {"datasource_id": "scr_002145_neuromorpho_modelimage"},
    {"content": "fMRI BOLD neuroimaging dataset with multiple subjects"},
    {"datasource_id": "unknown_id", "content": "ion channel database"},
]

def main():
    """Print each example's TTFR estimate, first human-readable then as JSON."""
    for i, kwargs in enumerate(EXAMPLES, 1):
        est = estimate_ttfr(**kwargs)
        print(f"Example {i}: {kwargs}")
        print(f"  Summary: {est.summary}")
        print("  Assumptions:")
        for a in est.assumptions:
            print(f"    - {a}")
        # JSON-serialisable rendering: summary and phase objects are
        # stringified so json.dumps can handle them.
        print("\nJSON format:")
        out = {
            "summary": str(est.summary),
            # NOTE(review): est.phases presumably maps phase name -> phase
            # object — confirm against ttfr_estimator.
            "phases": {k: str(v) for k, v in est.phases.items()},
            "assumptions": est.assumptions,
        }
        print(json.dumps(out, indent=2))
        print()

if __name__ == "__main__":
    main()
6 changes: 5 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,9 @@ async def health():
async def chat_endpoint(msg: ChatMessage):
try:
start_time = time.time()
session_id = msg.session_id or "default"
response_text = await assistant.handle_chat(
session_id=msg.session_id or "default",
session_id=session_id,
query=msg.query,
reset=bool(msg.reset),
)
Expand All @@ -125,6 +126,9 @@ async def chat_endpoint(msg: ChatMessage):
"timestamp": datetime.utcnow().isoformat(),
"reset": bool(msg.reset),
}
extra = assistant.get_last_response_metadata(session_id)
if extra:
metadata.update(extra)
return ChatResponse(response=response_text, metadata=metadata)
except asyncio.TimeoutError:
raise HTTPException(
Expand Down
12 changes: 8 additions & 4 deletions backend/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def __init__(self):
self.query_char_limit = 8000

# Enable only if everything is present
self.is_enabled = all(
self._is_enabled = all(
[self.project_id, self.region, self.index_endpoint_full, self.deployed_id]
)
if not self.is_enabled:
if not self._is_enabled:
logger.warning(
"Vector search disabled due to incomplete GCP env: "
f"project={bool(self.project_id)}, region={bool(self.region)}, "
Expand All @@ -109,7 +109,7 @@ def __init__(self):
self.bq = bigquery.Client(project=self.project_id)
except Exception as e:
logger.error(f"GCP client initialization failed: {e}")
self.is_enabled = False
self._is_enabled = False
return

try:
Expand All @@ -123,7 +123,11 @@ def __init__(self):
logger.info(f"Vector search initialized on device={self.device} using {self.embed_model_name}")
except Exception as e:
logger.error(f"Embedding model initialization failed: {e}")
self.is_enabled = False
self._is_enabled = False

@property
def is_enabled(self) -> bool:
    """Whether vector search is usable.

    Reads the private ``_is_enabled`` flag defensively: initialisation may
    bail out before the attribute is ever assigned, in which case the
    retriever reports itself as disabled instead of raising.
    """
    try:
        return self._is_enabled
    except AttributeError:
        return False

# Embedding
def _embed(self, text: str) -> List[float]:
Expand Down
1 change: 1 addition & 0 deletions backend/service-account.json.example
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
Loading