diff --git a/.gitignore b/.gitignore index 3ba759d..716e313 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,6 @@ cython_debug/ #Embedded data files data_processing/embeddings.jsonl + +# Local planning / PR-unrelated (do not push) +.plan/ diff --git a/backend/Dockerfile b/backend/Dockerfile index f246a0e..fd50317 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -5,7 +5,7 @@ RUN pip install uv COPY pyproject.toml ./ COPY .env ./ -COPY backend/service-account.json ./service-account.json +COPY backend/service-account.json.example ./service-account.json RUN UV_HTTP_TIMEOUT=300 uv sync diff --git a/backend/agents.py b/backend/agents.py index ecba861..f4ff13f 100644 --- a/backend/agents.py +++ b/backend/agents.py @@ -10,6 +10,7 @@ from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search from retrieval import get_retriever +from ttfr_estimator import estimate_ttfr # LLM (Gemini) client setup try: @@ -442,6 +443,22 @@ def fuse_results(state: AgentState) -> AgentState: else: combined[doc_id] = {**res, "final_score": res.get("_score", 0) * 0.4} all_sorted = sorted(combined.values(), key=lambda x: x.get("final_score", 0), reverse=True) + for result in all_sorted: + try: + est = estimate_ttfr( + datasource_id=result.get("datasource_id"), + metadata=result.get("metadata") or result.get("detailed_info") or {}, + content=result.get("content") or result.get("description") or "", + ) + result["ttfr_estimate"] = { + "summary": str(est.summary), + "min_days": est.summary.min_days, + "max_days": est.summary.max_days, + "assumptions": est.assumptions, + } + except Exception as e: + result["ttfr_estimate"] = None + print(f"TTFR estimate failed for result: {e}") print(f"Results summary: KS={len(ks_results)}, Vector={len(vector_results)}, Combined={len(all_sorted)}") page_size = 15 return {**state, "all_results": all_sorted, "final_results": all_sorted[:page_size]} @@ -485,8 +502,12 @@ class NeuroscienceAssistant: def __init__(self): self.chat_history: Dict[str, List[str]] = {} self.session_memory: Dict[str, Dict[str, Any]] = {} + self._last_response_metadata: Dict[str, dict] = {} self.graph = self._build_graph() + def get_last_response_metadata(self, session_id: str) -> dict: + return self._last_response_metadata.get(session_id, {}) + def _build_graph(self): workflow = StateGraph(AgentState) workflow.add_node("prepare", extract_keywords_and_rewrite) @@ -503,6 +524,7 @@ def _build_graph(self): def reset_session(self, session_id: str): self.chat_history.pop(session_id, None) self.session_memory.pop(session_id, None) + self._last_response_metadata.pop(session_id, None) async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> str: @@ -537,6 +559,16 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> "last_text": f"{prev_text}\n\n{text}"[-12000:], }) self.session_memory[session_id] = mem + ttfr_estimates = [] + for r in batch: + te = r.get("ttfr_estimate") + if te and isinstance(te, dict): + ttfr_estimates.append({ + "id": r.get("id") or r.get("_id"), + "title": r.get("title_guess") or r.get("title"), + "ttfr_summary": te.get("summary"), + }) + self._last_response_metadata[session_id] = {"ttfr_estimates": ttfr_estimates} self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {text}"]) if len(self.chat_history[session_id]) > 20: self.chat_history[session_id] = self.chat_history[session_id][-20:] @@ -569,6 +601,16 @@ async def handle_chat(self, session_id: str, query: str, reset: bool = False) -> "intents": final_state.get("intents", [QueryIntent.DATA_DISCOVERY.value]), "last_text": response_text, } + ttfr_estimates = [] + for r in final_state.get("final_results", [])[:15]: + te = r.get("ttfr_estimate") + if te and isinstance(te, dict): + ttfr_estimates.append({ + "id": r.get("id") or r.get("_id"), + "title": r.get("title_guess") or r.get("title"), + "ttfr_summary": te.get("summary"), + }) + self._last_response_metadata[session_id] = {"ttfr_estimates": ttfr_estimates} self.chat_history[session_id].extend([f"User: {query}", f"Assistant: {response_text}"]) if len(self.chat_history[session_id]) > 20: diff --git a/backend/demo_ttfr.py b/backend/demo_ttfr.py new file mode 100644 index 0000000..1a0bd5e --- /dev/null +++ b/backend/demo_ttfr.py @@ -0,0 +1,30 @@ +import json +from ttfr_estimator import estimate_ttfr + +EXAMPLES = [ + {"datasource_id": "scr_005031_openneuro"}, + {"datasource_id": "scr_017612_ebrains"}, + {"datasource_id": "scr_002145_neuromorpho_modelimage"}, + {"content": "fMRI BOLD neuroimaging dataset with multiple subjects"}, + {"datasource_id": "unknown_id", "content": "ion channel database"}, +] + +def main(): + for i, kwargs in enumerate(EXAMPLES, 1): + est = estimate_ttfr(**kwargs) + print(f"Example {i}: {kwargs}") + print(f" Summary: {est.summary}") + print(" Assumptions:") + for a in est.assumptions: + print(f" - {a}") + print("\nJSON format:") + out = { + "summary": str(est.summary), + "phases": {k: str(v) for k, v in est.phases.items()}, + "assumptions": est.assumptions, + } + print(json.dumps(out, indent=2)) + print() + +if __name__ == "__main__": + main() diff --git a/backend/main.py b/backend/main.py index 5f59553..2051375 100644 --- a/backend/main.py +++ b/backend/main.py @@ -113,8 +113,9 @@ async def health(): async def chat_endpoint(msg: ChatMessage): try: start_time = time.time() + session_id = msg.session_id or "default" response_text = await assistant.handle_chat( - session_id=msg.session_id or "default", + session_id=session_id, query=msg.query, reset=bool(msg.reset), ) @@ -125,6 +126,9 @@ async def chat_endpoint(msg: ChatMessage): "timestamp": datetime.utcnow().isoformat(), "reset": bool(msg.reset), } + extra = assistant.get_last_response_metadata(session_id) + if extra: + metadata.update(extra) return ChatResponse(response=response_text, metadata=metadata) except asyncio.TimeoutError: raise HTTPException( diff --git a/backend/retrieval.py b/backend/retrieval.py index 31a0452..11f26ba 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -89,10 +89,10 @@ def __init__(self): self.query_char_limit = 8000 # Enable only if everything is present - self.is_enabled = all( + self._is_enabled = all( [self.project_id, self.region, self.index_endpoint_full, self.deployed_id] ) - if not self.is_enabled: + if not self._is_enabled: logger.warning( "Vector search disabled due to incomplete GCP env: " f"project={bool(self.project_id)}, region={bool(self.region)}, " @@ -109,7 +109,7 @@ def __init__(self): self.bq = bigquery.Client(project=self.project_id) except Exception as e: logger.error(f"GCP client initialization failed: {e}") - self.is_enabled = False + self._is_enabled = False return try: @@ -123,7 +123,11 @@ def __init__(self): logger.info(f"Vector search initialized on device={self.device} using {self.embed_model_name}") except Exception as e: logger.error(f"Embedding model initialization failed: {e}") - self.is_enabled = False + self._is_enabled = False + + @property + def is_enabled(self) -> bool: + return getattr(self, "_is_enabled", False) # Embedding def _embed(self, text: str) -> List[float]: diff --git a/backend/service-account.json.example b/backend/service-account.json.example new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/backend/service-account.json.example @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/backend/ttfr_estimator.py b/backend/ttfr_estimator.py new file mode 100644 index 0000000..3ff24ac --- /dev/null +++ b/backend/ttfr_estimator.py @@ -0,0 +1,317 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + + +class AccessType(Enum): + OPEN = "open" + LOGIN = "login" + APPROVAL = "approval" + + +class FormatType(Enum): + BIDS = "bids" + NWB = "nwb" + CUSTOM = "custom" + STANDARD_IMAGE = "standard_image" + + +class ModalityComplexity(Enum): + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + VERY_HIGH = "very_high" + + +@dataclass +class TimeRange: + min_days: float + max_days: float + + def __str__(self) -> str: + if self.max_days < 1: + min_hours = round(self.min_days * 24) + max_hours = round(self.max_days * 24) + if min_hours == max_hours: + unit = "hour" if min_hours == 1 else "hours" + return f"{min_hours} {unit}" + return f"{min_hours}–{max_hours} hours" + if self.min_days < 1: + min_hours = round(self.min_days * 24) + max_days = round(self.max_days) + hours_unit = "hour" if min_hours == 1 else "hours" + days_unit = "day" if max_days == 1 else "days" + return f"{min_hours} {hours_unit}–{max_days} {days_unit}" + min_days = round(self.min_days) + max_days = round(self.max_days) + if min_days == max_days: + unit = "day" if min_days == 1 else "days" + return f"{min_days} {unit}" + return f"{min_days}–{max_days} days" + + +@dataclass +class TTFREstimate: + summary: TimeRange + phases: Dict[str, TimeRange] + assumptions: List[str] + + +ACCESS_DAYS = { + AccessType.OPEN: (0, 0.5), + AccessType.LOGIN: (0.5, 2), + AccessType.APPROVAL: (2, 14), +} + +FORMAT_DAYS = { + FormatType.BIDS: (0.5, 2), + FormatType.NWB: (0.5, 2), + FormatType.CUSTOM: (1, 5), + FormatType.STANDARD_IMAGE: (0.25, 1), +} + +MODALITY_DAYS = { + ModalityComplexity.LOW: (0.25, 1), + ModalityComplexity.MEDIUM: (0.5, 2), + ModalityComplexity.HIGH: (1, 3), + ModalityComplexity.VERY_HIGH: (2, 5), +} + +DOC_QUALITY_MULTIPLIER = {"high": 1.0, "medium": 1.2, "low": 1.5} + +# Per-datasource defaults: access, format, typical_modality, doc_quality. +# Add or edit entries here when adding new datasources. +DATASOURCE_CONFIG: Dict[str, Dict[str, Any]] = { + "scr_005031_openneuro": { + "access": AccessType.OPEN, + "format": FormatType.BIDS, + "typical_modality": ModalityComplexity.HIGH, + "doc_quality": "high", + }, + "scr_017571_dandi": { + "access": AccessType.OPEN, + "format": FormatType.NWB, + "typical_modality": ModalityComplexity.HIGH, + "doc_quality": "high", + }, + "scr_007271_modeldb_models": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "medium", + }, + "scr_017612_ebrains": { + "access": AccessType.LOGIN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.VERY_HIGH, + "doc_quality": "high", + }, + "scr_003510_cil_images": { + "access": AccessType.OPEN, + "format": FormatType.STANDARD_IMAGE, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "medium", + }, + "scr_002145_neuromorpho_modelimage": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.LOW, + "doc_quality": "high", + }, + "scr_017041_sparc": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.HIGH, + "doc_quality": "high", + }, + "scr_002978_aba_expression": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "high", + }, + "scr_005069_brainminds": { + "access": AccessType.APPROVAL, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.VERY_HIGH, + "doc_quality": "medium", + }, + "scr_002721_gensat_geneexpression": { + "access": AccessType.OPEN, + "format": FormatType.STANDARD_IMAGE, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "medium", + }, + "scr_003105_neurondb_currents": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.LOW, + "doc_quality": "medium", + }, + "scr_006131_hba_atlas": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "high", + }, + "scr_014194_icg_ionchannels": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.LOW, + "doc_quality": "medium", + }, + "scr_013705_neuroml_models": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "high", + }, + "scr_014306_bbp_cellmorphology": { + "access": AccessType.LOGIN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "high", + }, + "scr_016433_conp": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.HIGH, + "doc_quality": "medium", + }, + "scr_006274_neuroelectro_ephys": { + "access": AccessType.OPEN, + "format": FormatType.CUSTOM, + "typical_modality": ModalityComplexity.MEDIUM, + "doc_quality": "high", + }, +} + +# Classification checks keywords from VERY_HIGH down to LOW; first match wins. +MODALITY_KEYWORDS = { + ModalityComplexity.VERY_HIGH: ["multimodal", "multi-modal", "combined", "integrative"], + ModalityComplexity.HIGH: ["mri", "fmri", "pet", "meg", "eeg", "bold", "neuroimaging"], + ModalityComplexity.MEDIUM: [ + "microscopy", + "image", + "gene expression", + "single cell", + "ephys", + "electrophysiology", + ], + ModalityComplexity.LOW: ["simulated", "model", "morphology", "ion channel", "database"], +} + + +def infer_modality_from_keywords(text: str) -> Optional[ModalityComplexity]: + if not (text and text.strip()): + return None + lower = text.lower() + for level in (ModalityComplexity.VERY_HIGH, ModalityComplexity.HIGH, ModalityComplexity.MEDIUM, ModalityComplexity.LOW): + if any(kw in lower for kw in MODALITY_KEYWORDS[level]): + return level + return None + + +def detect_multimodal(text: Optional[str]) -> bool: + if not text: + return False + lower = text.lower() + keywords = ["multimodal", "multi-modal", "combined", "integrative", " and "] + return any(kw in lower for kw in keywords) + + +def assess_documentation_quality(metadata: Optional[Dict[str, Any]]) -> str: + if not metadata: + return "low" + desc = metadata.get("description") or "" + dc = metadata.get("dc") + if isinstance(dc, dict): + desc = desc or dc.get("description") or "" + links = metadata.get("documentation_url") or metadata.get("url") or metadata.get("identifier") + if isinstance(links, list): + links = len(links) > 0 + has_doc_link = bool(links) + if len(str(desc)) > 200 and has_doc_link: + return "high" + if len(str(desc)) > 50 or has_doc_link: + return "medium" + return "low" + + +def estimate_ttfr( + datasource_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + content: Optional[str] = None, +) -> TTFREstimate: + """ + Estimate time to first result for a datasource. + + Parameters + ---------- + datasource_id : str, optional + Known datasource key from DATASOURCE_CONFIG. If present, its access, + format, typical_modality, and doc_quality are used. + metadata : dict, optional + Datasource metadata (e.g. title, description, dc, documentation_url). + Used for documentation quality and to refine modality. + content : str, optional + Free-text description; used to infer modality and multimodality. + + Returns + ------- + TTFREstimate + Summary time range, per-phase breakdown, and assumptions. + + Notes + ----- + If datasource_id is missing or not in DATASOURCE_CONFIG, falls back to + heuristics from metadata and content. Assumptions list explains how the + estimate was derived. + """ + assumptions: List[str] = [] + cfg = DATASOURCE_CONFIG.get(datasource_id or "") if datasource_id else None + + if cfg: + access = cfg["access"] + fmt = cfg["format"] + modality = cfg["typical_modality"] + doc_quality = cfg.get("doc_quality", "medium") + assumptions.append(f"Using datasource config for {datasource_id}") + else: + access = AccessType.OPEN + fmt = FormatType.CUSTOM + modality = ModalityComplexity.MEDIUM + doc_quality = "medium" + assumptions.append("Unknown datasource; using default access OPEN, format CUSTOM, modality MEDIUM") + inferred = infer_modality_from_keywords((content or "") + " " + str(metadata or "")) + if inferred: + modality = inferred + assumptions.append(f"Inferred modality {modality.value} from content/metadata") + if metadata: + doc_quality = assess_documentation_quality(metadata) + assumptions.append(f"Documentation quality: {doc_quality}") + + content_for_multimodal = (content or "") + " " + ((metadata or {}).get("description") or "") + if detect_multimodal(content_for_multimodal.strip() or None): + modality = ModalityComplexity.VERY_HIGH + assumptions.append("Multimodal content detected; using VERY_HIGH modality") + + mult = DOC_QUALITY_MULTIPLIER.get(doc_quality, 1.2) + a_min, a_max = ACCESS_DAYS[access] + p_min, p_max = FORMAT_DAYS[fmt] + m_min, m_max = MODALITY_DAYS[modality] + p_min, p_max = p_min * mult, p_max * mult + m_min, m_max = m_min * mult, m_max * mult + + total_min = a_min + p_min + m_min + total_max = a_max + p_max + m_max + phases = { + "access": TimeRange(a_min, a_max), + "preprocessing": TimeRange(p_min, p_max), + "first_output": TimeRange(m_min, m_max), + } + return TTFREstimate( + summary=TimeRange(total_min, total_max), + phases=phases, + assumptions=assumptions, + ) diff --git a/pyproject.toml b/pyproject.toml index 67d020c..a4fd866 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,10 @@ dependencies = [ "uvicorn>=0.34.3", ] +[tool.pytest.ini_options] +pythonpath = ["backend"] +testpaths = ["tests"] + [project.optional-dependencies] dev = [ "pytest>=7.0.0", diff --git a/scripts/verify_ttfr_m1_m2.sh b/scripts/verify_ttfr_m1_m2.sh new file mode 100644 index 0000000..a5bc00a --- /dev/null +++ b/scripts/verify_ttfr_m1_m2.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# Verify Milestone 1 and 2 locally. No GCP credentials required. +set -e +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +echo "=== Milestone 1: TTFR module and demo ===" +cd "$ROOT/backend" +python3 -c " +from ttfr_estimator import estimate_ttfr +e = estimate_ttfr(datasource_id='scr_005031_openneuro') +assert e.summary.min_days >= 0 +print(' Import and estimate_ttfr: OK') +print(' Sample:', e.summary) +" +python3 demo_ttfr.py | head -20 +echo " Demo: OK (output above)" + +echo "" +echo "=== Milestone 2: Tests in tests/ ===" +cd "$ROOT" +python3 -m pytest tests/ -v --tb=short 2>&1 | tail -25 + +echo "" +echo "=== Done. No GCP credentials required. ===" diff --git a/tests/test_ttfr_estimator.py b/tests/test_ttfr_estimator.py new file mode 100644 index 0000000..887a1cd --- /dev/null +++ b/tests/test_ttfr_estimator.py @@ -0,0 +1,165 @@ +import pytest +from ttfr_estimator import ( + TTFREstimate, + TimeRange, + ModalityComplexity, + estimate_ttfr, + infer_modality_from_keywords, + detect_multimodal, + assess_documentation_quality, +) + + +class TestInferModalityFromKeywords: + def test_empty_string_returns_none(self): + assert infer_modality_from_keywords("") is None + + def test_whitespace_only_returns_none(self): + assert infer_modality_from_keywords(" \n\t ") is None + + def test_very_high_multimodal(self): + assert infer_modality_from_keywords("multimodal imaging study") == ModalityComplexity.VERY_HIGH + + def test_very_high_combined(self): + assert infer_modality_from_keywords("combined fMRI and EEG") == ModalityComplexity.VERY_HIGH + + def test_high_mri(self): + assert infer_modality_from_keywords("MRI dataset") == ModalityComplexity.HIGH + + def test_high_fmri(self): + assert infer_modality_from_keywords("fMRI BOLD") == ModalityComplexity.HIGH + + def test_high_neuroimaging(self): + assert infer_modality_from_keywords("neuroimaging pipeline") == ModalityComplexity.HIGH + + def test_medium_microscopy(self): + assert infer_modality_from_keywords("microscopy images") == ModalityComplexity.MEDIUM + + def test_medium_ephys(self): + assert infer_modality_from_keywords("electrophysiology recording") == ModalityComplexity.MEDIUM + + def test_low_model(self): + assert infer_modality_from_keywords("computational model") == ModalityComplexity.LOW + + def test_low_ion_channel(self): + assert infer_modality_from_keywords("ion channel database") == ModalityComplexity.LOW + + def test_order_very_high_wins_over_high(self): + text = "multimodal imaging with fMRI" + assert infer_modality_from_keywords(text) == ModalityComplexity.VERY_HIGH + + def test_no_keyword_returns_none(self): + assert infer_modality_from_keywords("random text with no modality") is None + + def test_case_insensitive(self): + assert infer_modality_from_keywords("MRI AND EEG") == ModalityComplexity.HIGH + + +class TestDetectMultimodal: + def test_none_returns_false(self): + assert detect_multimodal(None) is False + + def test_empty_string_returns_false(self): + assert detect_multimodal("") is False + + def test_multimodal_returns_true(self): + assert detect_multimodal("multimodal dataset") is True + + def test_multi_hyphen_modal_returns_true(self): + assert detect_multimodal("multi-modal data") is True + + def test_combined_returns_true(self): + assert detect_multimodal("combined approach") is True + + def test_integrative_returns_true(self): + assert detect_multimodal("integrative analysis") is True + + def test_no_match_returns_false(self): + assert detect_multimodal("single modality fMRI only") is False + + +class TestAssessDocumentationQuality: + def test_none_returns_low(self): + assert assess_documentation_quality(None) == "low" + + def test_empty_dict_returns_low(self): + assert assess_documentation_quality({}) == "low" + + def test_long_description_with_link_returns_high(self): + meta = { + "description": "x" * 201, + "documentation_url": "https://example.com/docs", + } + assert assess_documentation_quality(meta) == "high" + + def test_short_description_no_link_returns_low(self): + assert assess_documentation_quality({"description": "short"}) == "low" + + def test_medium_description_returns_medium(self): + meta = {"description": "a" * 51} + assert assess_documentation_quality(meta) == "medium" + + def test_url_only_returns_medium(self): + assert assess_documentation_quality({"url": "https://example.com"}) == "medium" + + def test_dc_description_used(self): + meta = {"dc": {"description": "y" * 51}} + assert assess_documentation_quality(meta) == "medium" + + +class TestTimeRangeStr: + def test_hours_only(self): + tr = TimeRange(0.25, 0.5) + assert "hour" in str(tr) + + def test_singular_day(self): + tr = TimeRange(1.0, 1.0) + assert str(tr) == "1 day" + + def test_singular_hour(self): + tr = TimeRange(1/24, 1/24) + assert str(tr) == "1 hour" + + def test_days_range(self): + tr = TimeRange(2.0, 5.0) + assert str(tr) == "2–5 days" + + def test_mixed_units(self): + tr = TimeRange(0.5, 2.0) + s = str(tr) + assert "hour" in s and "day" in s + + +class TestEstimateTtfr: + def test_known_datasource_uses_config(self): + est = estimate_ttfr(datasource_id="scr_005031_openneuro") + assert isinstance(est, TTFREstimate) + assert "scr_005031_openneuro" in est.assumptions[0] + assert est.summary.min_days >= 0 + assert est.summary.max_days >= est.summary.min_days + assert "access" in est.phases and "preprocessing" in est.phases and "first_output" in est.phases + + def test_known_datasource_ebrains(self): + est = estimate_ttfr(datasource_id="scr_017612_ebrains") + assert "scr_017612_ebrains" in est.assumptions[0] + assert est.summary.min_days >= 2 + + def test_unknown_datasource_infers_from_content(self): + est = estimate_ttfr(content="fMRI BOLD neuroimaging") + assert any("Unknown datasource" in a for a in est.assumptions) + assert any("modality" in a.lower() for a in est.assumptions) + + def test_empty_input_uses_defaults(self): + est = estimate_ttfr() + assert any("Unknown datasource" in a for a in est.assumptions) + assert len(est.phases) == 3 + + def test_metadata_doc_quality_affects_estimate(self): + est_low = estimate_ttfr(metadata={"description": "x"}) + est_high = estimate_ttfr(metadata={"description": "y" * 201, "url": "https://x.com"}) + assert est_high.summary.max_days <= est_low.summary.max_days * 1.6 + + def test_multimodal_content_bumps_to_very_high(self): + est = estimate_ttfr(content="multimodal fMRI and EEG combined") + assert any("Multimodal" in a for a in est.assumptions) + assert est.phases["first_output"].max_days >= 2