diff --git a/README.md b/README.md index 6376e50..bcb7c26 100644 --- a/README.md +++ b/README.md @@ -366,6 +366,42 @@ human_comparison.py Used to compare model forecasts against human forecasts. +## Historical-replay mode (benchmarking against human forecasters) + +When benchmarking the pipeline against human forecasters on past questions, +the model must not be allowed to see sources that didn't exist (or contained +different content) at the time the human forecasted. Historical-replay mode +enforces this by reading a single per-question field, `ForecastQuestion.as_of_date`: + +- When `as_of_date` is `None` (default), the pipeline behaves exactly as in + live mode. No code paths change. +- When `as_of_date` is set, the search backend receives `end_date=as_of_date`, + the cache key incorporates the cutoff, post-retrieval filtering drops any + result dated after the cutoff (and any undated result whose date cannot be + cheaply recovered), dashboard URLs are rewritten to the closest Wayback + snapshot at or before the cutoff (or suppressed if none exists), and the + extraction stage fetches from Wayback. Wayback fallback to live is logged + at INFO and recorded in `Document.fetch_strategy`, never silent. + +The LLM "historical roleplay" prompt is *not* automatically enabled by +`as_of_date`; it lives behind a separate `historical_roleplay=True` flag on +`SearchStagePipeline` because its effect on query quality is harder to +predict. Turn it on for the benchmark and off for production. + +What this mode does NOT fix: the LLMs themselves were trained on data that +postdates many of our benchmark questions. Retrieval fairness ≠ model +fairness. The `retrieval_free_baseline_forecast` metric in +`bioscancast/stages/eval_stage/contamination.py` reports how well the LLM +forecasts with no evidence at all; a small gap between that and the full +pipeline is itself evidence of training-data leakage and must be reported +alongside the headline Brier/log scores. + +`filter_caught_contamination_rate` is also exposed by the same module. It +is a **lower bound** on contamination — it only counts post-cutoff results +whose `published_date` is known. Undated results and results whose content +changed post-cutoff are invisible to it. Reports MUST surface this caveat; +the metric's docstring repeats it for the same reason. + --- # Datasets diff --git a/bioscancast/extraction/fetcher.py b/bioscancast/extraction/fetcher.py index d0d71ae..6cc0c01 100644 --- a/bioscancast/extraction/fetcher.py +++ b/bioscancast/extraction/fetcher.py @@ -7,6 +7,8 @@ from curl_cffi import requests as curl_requests +from bioscancast.stages.search_stage.wayback import closest_snapshot_before + from .config import ExtractionConfig logger = logging.getLogger(__name__) @@ -25,6 +27,8 @@ class FetchResult: content_bytes: Optional[bytes] fetched_at: datetime error: Optional[str] + fetch_strategy: str = "live" + snapshot_timestamp: Optional[datetime] = None def _sniff_content_type(content: bytes) -> Optional[str]: @@ -51,6 +55,7 @@ def fetch( url: str, *, config: ExtractionConfig | None = None, + as_of_date: Optional[datetime] = None, ) -> FetchResult: """Fetch a URL and return the result. Never raises on network errors. @@ -58,7 +63,56 @@ def fetch( ExtractionConfig.impersonate) to avoid Cloudflare/JA3-based blocks that reject httpx and requests. The impersonation profile sets a matching User-Agent automatically. + + Historical-replay mode: when ``as_of_date`` is set the function first + asks Wayback for the closest capture at-or-before that date and fetches + the raw snapshot bytes via the ``id_`` modifier. The returned FetchResult + carries ``fetch_strategy="wayback"`` and ``snapshot_timestamp`` set to + the capture time. If no snapshot exists, or the Wayback fetch errors, + we fall back to a live fetch and tag the result + ``fetch_strategy="wayback_fallback_to_live"`` so audit reports can see + the leak. The fallback is logged at INFO — never silent. """ + if as_of_date is not None: + snapshot = closest_snapshot_before(url, as_of_date) + if snapshot is not None: + snapshot_dt, snapshot_url = snapshot + wb_result = _fetch_via_curl( + target_url=snapshot_url, + reported_url=url, + config=config, + ) + if wb_result.error is None and wb_result.content_bytes is not None: + wb_result.fetch_strategy = "wayback" + wb_result.snapshot_timestamp = snapshot_dt + return wb_result + logger.info( + "Wayback fetch failed for %s (snapshot %s, error=%s); " + "falling back to live", + url, snapshot_dt.isoformat(), wb_result.error, + ) + else: + logger.info( + "No Wayback snapshot for %s at-or-before %s; falling back to live", + url, as_of_date.isoformat(), + ) + live_result = _fetch_via_curl(target_url=url, reported_url=url, config=config) + live_result.fetch_strategy = "wayback_fallback_to_live" + return live_result + + return _fetch_via_curl(target_url=url, reported_url=url, config=config) + + +def _fetch_via_curl( + *, + target_url: str, + reported_url: str, + config: ExtractionConfig | None, +) -> FetchResult: + """Issue the actual HTTP GET. ``target_url`` is what we hit (may be a + Wayback ``id_`` URL); ``reported_url`` is what we record in + ``FetchResult.url`` so downstream consumers see the original publisher + URL, not archive.org.""" cfg = config or ExtractionConfig() fetched_at = datetime.now(timezone.utc) @@ -66,7 +120,7 @@ def fetch( # curl_cffi's streaming Response is not a context manager in the # installed version, so we close it explicitly in a finally block. response = curl_requests.get( - url, + target_url, stream=True, timeout=cfg.fetch_timeout_seconds, impersonate=cfg.impersonate, @@ -76,7 +130,7 @@ def fetch( content_length = response.headers.get("content-length") if content_length and int(content_length) > cfg.fetch_max_bytes: return FetchResult( - url=url, + url=reported_url, final_url=str(response.url), status_code=response.status_code, content_type=_normalize_content_type( @@ -95,7 +149,7 @@ def fetch( total += len(chunk) if total > cfg.fetch_max_bytes: return FetchResult( - url=url, + url=reported_url, final_url=str(response.url), status_code=response.status_code, content_type=_normalize_content_type( @@ -118,7 +172,7 @@ def fetch( raw_ct = _sniff_content_type(content_bytes) or raw_ct return FetchResult( - url=url, + url=reported_url, final_url=str(response.url), status_code=response.status_code, content_type=raw_ct, @@ -130,10 +184,10 @@ def fetch( response.close() except Exception as exc: - logger.warning("Fetch failed for %s: %s", url, exc) + logger.warning("Fetch failed for %s: %s", target_url, exc) return FetchResult( - url=url, - final_url=url, + url=reported_url, + final_url=reported_url, status_code=None, content_type=None, content_bytes=None, diff --git a/bioscancast/extraction/pipeline.py b/bioscancast/extraction/pipeline.py index 0ae2d99..47dd1f4 100644 --- a/bioscancast/extraction/pipeline.py +++ b/bioscancast/extraction/pipeline.py @@ -19,10 +19,22 @@ class ExtractionPipeline: - """Orchestrates document fetching, parsing, and chunk normalization.""" + """Orchestrates document fetching, parsing, and chunk normalization. - def __init__(self, *, config: ExtractionConfig | None = None) -> None: + ``as_of_date`` opts the fetcher into Wayback-rewrite mode. See + ``bioscancast.extraction.fetcher.fetch`` for the strategy semantics + (live / wayback / wayback_fallback_to_live). The resulting strategy + and snapshot timestamp are copied onto each Document for audit. + """ + + def __init__( + self, + *, + config: ExtractionConfig | None = None, + as_of_date: Optional[datetime] = None, + ) -> None: self._config = config or ExtractionConfig() + self._as_of_date = as_of_date self._parsers = get_parsers(pdf_max_pages=self._config.pdf_max_pages) # Lazily constructed on first PDF that reaches the refiner step. self._docling_refiner = None @@ -54,7 +66,11 @@ def extract_one(self, filtered_doc: FilteredDocument) -> Document: doc_id = f"doc-{filtered_doc.result_id}" # Step 1: Fetch - fetch_result = fetch(filtered_doc.url, config=self._config) + fetch_result = fetch( + filtered_doc.url, + config=self._config, + as_of_date=self._as_of_date, + ) if fetch_result.error or fetch_result.content_bytes is None: return self._make_failed_document( @@ -169,6 +185,9 @@ def extract_one(self, filtered_doc: FilteredDocument) -> Document: chunks=chunks, extracted_tables=extracted_tables, extracted_dates=extracted_dates, + fetch_strategy=fetch_result.fetch_strategy, + snapshot_timestamp=fetch_result.snapshot_timestamp, + cutoff_applied=self._as_of_date, ) def _get_docling_refiner(self): @@ -212,6 +231,9 @@ def _make_failed_document( error_message=error, http_status=fetch_result.status_code if fetch_result else None, content_type=fetch_result.content_type if fetch_result else None, + fetch_strategy=fetch_result.fetch_strategy if fetch_result else "live", + snapshot_timestamp=fetch_result.snapshot_timestamp if fetch_result else None, + cutoff_applied=self._as_of_date, ) def _build_chunks( diff --git a/bioscancast/filtering/models.py b/bioscancast/filtering/models.py index 058b659..facf320 100644 --- a/bioscancast/filtering/models.py +++ b/bioscancast/filtering/models.py @@ -15,6 +15,13 @@ class ForecastQuestion: pathogen: Optional[str] = None event_type: Optional[str] = None resolution_criteria: Optional[str] = None + # Historical-replay cutoff. When None (default), the pipeline runs in live + # mode and uses datetime.now() everywhere. When set, every cutoff-sensitive + # module (freshness scoring, search backend date filter, cache key, + # post-retrieval filter, dashboard Wayback rewrite, extraction Wayback + # rewrite, optional decomposition roleplay) treats this as "now" so the + # model sees only what a human forecaster could have seen at this moment. + as_of_date: Optional[datetime] = None @dataclass @@ -43,6 +50,15 @@ class SearchResult: retrieval_reason: Optional[str] = None contains_aggregator_forecast: bool = False search_stage_score: float = 0.0 + # Provenance for the date used to evaluate the historical-mode cutoff. + # One of: "backend" (Tavily/Google returned a date), "url_slug", + # "last_modified", "wayback_first_seen", "wayback_snapshot" (for dashboards + # rewritten to Wayback), or None (live mode, or date came from the backend + # in a way that didn't go through the recovery chain). + published_date_source: Optional[str] = None + # The as_of_date that was applied when this result was produced, copied + # off the ForecastQuestion. None in live mode. Useful for post-hoc audits. + cutoff_applied: Optional[datetime] = None @dataclass diff --git a/bioscancast/schemas/document.py b/bioscancast/schemas/document.py index d415995..120ee6a 100644 --- a/bioscancast/schemas/document.py +++ b/bioscancast/schemas/document.py @@ -115,3 +115,13 @@ class Document: extracted_dates: List[str] = field(default_factory=list) """Date strings found anywhere in the document, preserved as-is.""" + + # ---- historical-replay provenance ---- + fetch_strategy: str = "live" + """How the bytes were obtained: 'live', 'wayback', or 'wayback_fallback_to_live'.""" + + snapshot_timestamp: Optional[datetime] = None + """Wayback capture timestamp when fetch_strategy == 'wayback'. None otherwise.""" + + cutoff_applied: Optional[datetime] = None + """The as_of_date that was active when this document was fetched. None in live mode.""" diff --git a/bioscancast/stages/eval_stage/contamination.py b/bioscancast/stages/eval_stage/contamination.py new file mode 100644 index 0000000..a237151 --- /dev/null +++ b/bioscancast/stages/eval_stage/contamination.py @@ -0,0 +1,194 @@ +"""Contamination diagnostics for the human-comparison benchmark. + +These metrics are *not* the same as proving fairness. They are reporting +aids: they let a reviewer see how much of the model's evidence base +demonstrably violated the cutoff, and how much of the model's edge over a +human came from training-data leakage rather than retrieval. + +Two metrics live here: + +* ``filter_caught_contamination_rate`` — a LOWER BOUND on contamination. + Counts SearchResults whose ``published_date`` is later than the cutoff + *and that nonetheless reached the final result list*. After the search + stage's cutoff filter runs this should be ~0; in live-mode benchmark + runs it can be substantial. Undated post-cutoff content is invisible + to this metric — that is the largest source of contamination it cannot + see, and reports MUST say so. + +* ``retrieval_free_baseline_forecast`` — asks the LLM to forecast with no + retrieved evidence, then scores it like any other forecast. A small gap + between this baseline and the full pipeline is itself evidence of + training-data leakage in the model's weights, distinct from leakage in + the retrieval pipeline. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from datetime import datetime +from typing import Iterable, List, Optional + +from bioscancast.filtering.models import ForecastQuestion, SearchResult +from bioscancast.llm.client import LLMClient + +logger = logging.getLogger(__name__) + + +# Phrased verbosely on purpose: a non-coder reading the eval report should +# not confuse "filter-caught" with "absolute". +FILTER_CAUGHT_CONTAMINATION_LOWER_BOUND_LABEL = ( + "filter_caught_contamination_rate " + "(LOWER BOUND — undated post-cutoff content is not counted)" +) + + +@dataclass +class ContaminationCounts: + total: int + post_cutoff_in_final: int + undated_in_final: int + pre_cutoff_in_final: int + + @property + def filter_caught_rate(self) -> float: + """Share of the final results whose dated publication is post-cutoff. + + IMPORTANT: this is a lower bound. It does not count undated results, + and it does not count pages whose content was edited after the + cutoff but whose first-publication date is pre-cutoff. Reports + derived from this metric must surface that caveat. + """ + if self.total == 0: + return 0.0 + return self.post_cutoff_in_final / self.total + + +def filter_caught_contamination_rate( + final_results: Iterable[SearchResult], + as_of: datetime, +) -> ContaminationCounts: + """Count post-cutoff results that nevertheless reached the final list. + + Pass the SearchResult list that the search stage *returned* (i.e. after + its own cutoff filter ran). In a well-behaved historical-replay run the + post_cutoff_in_final count should be 0; in a live-mode run it usually + won't be. + """ + post_cutoff = 0 + undated = 0 + pre_cutoff = 0 + total = 0 + for r in final_results: + total += 1 + if r.published_date is None: + undated += 1 + elif r.published_date > as_of: + post_cutoff += 1 + else: + pre_cutoff += 1 + return ContaminationCounts( + total=total, + post_cutoff_in_final=post_cutoff, + undated_in_final=undated, + pre_cutoff_in_final=pre_cutoff, + ) + + +@dataclass +class BaselineForecast: + question_id: str + options: List[str] + probabilities: List[float] + rationale: Optional[str] = None + + +def retrieval_free_baseline_forecast( + question: ForecastQuestion, + options: List[str], + llm_client: LLMClient, +) -> BaselineForecast: + """Ask the LLM to forecast the question with NO retrieved evidence. + + The gap between this baseline and the full-pipeline forecast quantifies + how much of the model's signal comes from retrieval vs. training-data + knowledge. A small gap on a 2024 question to a 2026-trained LLM is + strong evidence that the LLM already "knew the answer" — which is a + separate fairness problem from retrieval leakage that no amount of + pipeline filtering can fix. Report alongside Brier/log scores, never + in place of them. + """ + prompt = json.dumps( + { + "task": ( + "Forecast the probability of each option for this biosecurity " + "question using ONLY your prior knowledge. Do not assume any " + "additional research has been done. Return strict JSON: " + "{\"probabilities\": [], " + "\"rationale\": \"\"}. Probabilities must sum to 1." + ), + "question": question.text, + "pathogen": question.pathogen, + "region": question.region, + "target_date": ( + question.target_date.isoformat() if question.target_date else None + ), + "as_of_date": ( + question.as_of_date.date().isoformat() if question.as_of_date else None + ), + "options": options, + } + ) + try: + result = llm_client.generate_json(prompt) + except Exception: + logger.exception("Retrieval-free baseline LLM call failed for %s", question.id) + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="LLM call failed; uniform fallback", + ) + + raw_probs = result.get("probabilities") or [] + if not isinstance(raw_probs, list) or len(raw_probs) != len(options): + logger.warning( + "Baseline LLM returned malformed probabilities for %s: %r", + question.id, raw_probs, + ) + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="Malformed LLM output; uniform fallback", + ) + try: + probs = [float(p) for p in raw_probs] + except (TypeError, ValueError): + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="Non-numeric LLM output; uniform fallback", + ) + + total = sum(probs) + if total <= 0: + uniform = [1.0 / len(options)] * len(options) + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=uniform, + rationale="Zero-sum LLM output; uniform fallback", + ) + probs = [p / total for p in probs] + return BaselineForecast( + question_id=question.id, + options=options, + probabilities=probs, + rationale=result.get("rationale"), + ) diff --git a/bioscancast/stages/search_stage/backends/base.py b/bioscancast/stages/search_stage/backends/base.py index ad2ec1f..7fdf32e 100644 --- a/bioscancast/stages/search_stage/backends/base.py +++ b/bioscancast/stages/search_stage/backends/base.py @@ -17,6 +17,20 @@ class RawSearchResult: class SearchBackend(Protocol): - """Interface that all search backends must satisfy.""" + """Interface that all search backends must satisfy. - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: ... + ``start_date`` and ``end_date`` are optional YYYY-MM-DD bounds used by + historical-replay mode. Tavily's news endpoint requires the **pair** to be + set together (see ``tavily_backend.py``); passing ``end_date`` alone is + silently ignored. Backends that don't support either should accept and + ignore them — the post-retrieval cutoff filter in the pipeline will still + apply. + """ + + def search( + self, + query: str, + max_results: int = 10, + end_date: Optional[str] = None, + start_date: Optional[str] = None, + ) -> List[RawSearchResult]: ... diff --git a/bioscancast/stages/search_stage/backends/google_cse_backend.py b/bioscancast/stages/search_stage/backends/google_cse_backend.py index dbeb19b..7cbfaa1 100644 --- a/bioscancast/stages/search_stage/backends/google_cse_backend.py +++ b/bioscancast/stages/search_stage/backends/google_cse_backend.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import List +from typing import List, Optional from .base import RawSearchResult @@ -15,9 +15,16 @@ class GoogleCSEBackend: """Stub backend — raises NotImplementedError on use.""" - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, + query: str, + max_results: int = 10, + end_date: Optional[str] = None, + ) -> List[RawSearchResult]: raise NotImplementedError( "GoogleCSEBackend is a stub. Implement using the Google Custom Search " - "JSON API ($5/1k queries after 100/day free tier). See base.py for the " - "SearchBackend protocol." + "JSON API ($5/1k queries after 100/day free tier). When implementing, " + "the YYYY-MM-DD `end_date` argument should be honoured via the CSE " + "`sort=date:r:YYYYMMDD:YYYYMMDD` parameter for historical-replay mode. " + "See base.py for the SearchBackend protocol." ) diff --git a/bioscancast/stages/search_stage/backends/tavily_backend.py b/bioscancast/stages/search_stage/backends/tavily_backend.py index 5f5c8bd..7bce08d 100644 --- a/bioscancast/stages/search_stage/backends/tavily_backend.py +++ b/bioscancast/stages/search_stage/backends/tavily_backend.py @@ -31,17 +31,43 @@ def __init__(self, api_key: Optional[str] = None) -> None: "TAVILY_API_KEY is required. Set it in your environment or pass api_key." ) - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, + query: str, + max_results: int = 10, + end_date: Optional[str] = None, + start_date: Optional[str] = None, + ) -> List[RawSearchResult]: + # Date-window behavior (verified 2026-05-20, see + # ``specs/tavily-investigation-findings.md``): Tavily's news endpoint + # honors ``start_date`` + ``end_date`` only when **both** are passed + # together. Passing ``end_date`` alone is silently ignored and the + # results come back unfiltered. The pipeline is responsible for + # supplying a sensible ``start_date`` alongside any ``end_date``; if + # only ``end_date`` is passed here we drop it rather than send a + # request we know Tavily will misinterpret. The post-retrieval cutoff + # filter in ``SearchStagePipeline`` remains the authoritative defense. from tavily import TavilyClient # lazy import to avoid hard dep at import time client = TavilyClient(api_key=self._api_key) - try: - response = client.search( - query=query, - max_results=max_results, - topic="news", - include_answer=False, + kwargs: dict = { + "query": query, + "max_results": max_results, + "topic": "news", + "include_answer": False, + } + if start_date and end_date: + kwargs["start_date"] = start_date + kwargs["end_date"] = end_date + elif end_date and not start_date: + logger.warning( + "TavilyBackend received end_date=%s without start_date; " + "dropping (Tavily ignores end_date alone). Cutoff filter " + "will still apply post-retrieval.", + end_date, ) + try: + response = client.search(**kwargs) except Exception: logger.exception("Tavily search failed for query: %s", query) return [] diff --git a/bioscancast/stages/search_stage/cache.py b/bioscancast/stages/search_stage/cache.py index 4bb67c0..f2ae522 100644 --- a/bioscancast/stages/search_stage/cache.py +++ b/bioscancast/stages/search_stage/cache.py @@ -38,15 +38,28 @@ def __init__(self, db_path: str = "data/cache/search_cache.sqlite") -> None: self._conn.commit() @staticmethod - def _make_key(backend_name: str, query: str) -> str: - date_bucket = datetime.now(timezone.utc).strftime("%Y-%m-%d") + def _make_key( + backend_name: str, + query: str, + as_of_date: Optional[datetime] = None, + ) -> str: + # In historical-replay mode the bucket is the cutoff date, so that two + # benchmark runs against different cutoffs never share cache entries. + if as_of_date is not None: + date_bucket = as_of_date.strftime("%Y-%m-%d") + else: + date_bucket = datetime.now(timezone.utc).strftime("%Y-%m-%d") raw = f"{backend_name}|{query.strip().lower()}|{date_bucket}" return hashlib.sha256(raw.encode()).hexdigest() def get( - self, backend_name: str, query: str, max_age_hours: int = 24 + self, + backend_name: str, + query: str, + max_age_hours: int = 24, + as_of_date: Optional[datetime] = None, ) -> Optional[List[RawSearchResult]]: - key = self._make_key(backend_name, query) + key = self._make_key(backend_name, query, as_of_date) row = self._conn.execute( "SELECT results_json, created_at FROM search_cache WHERE cache_key = ?", (key,), @@ -63,8 +76,14 @@ def get( items = json.loads(row[0]) return [RawSearchResult(**item) for item in items] - def put(self, backend_name: str, query: str, results: List[RawSearchResult]) -> None: - key = self._make_key(backend_name, query) + def put( + self, + backend_name: str, + query: str, + results: List[RawSearchResult], + as_of_date: Optional[datetime] = None, + ) -> None: + key = self._make_key(backend_name, query, as_of_date) payload = json.dumps( [ { diff --git a/bioscancast/stages/search_stage/dashboard_lookup.py b/bioscancast/stages/search_stage/dashboard_lookup.py index 99aea4f..e3784c3 100644 --- a/bioscancast/stages/search_stage/dashboard_lookup.py +++ b/bioscancast/stages/search_stage/dashboard_lookup.py @@ -1,10 +1,20 @@ """Dashboard lookup — inject known pathogen dashboard URLs as SearchResults. +In live mode this returns the live dashboard URL with a synthetic +``published_date=None`` and freshness=1.0 — a sensible signal that the +dashboard "is current". In historical-replay mode (``question.as_of_date`` +set), live dashboards are dangerous: they return today's case counts even +for a question created in early 2025. We therefore look up the closest +Wayback snapshot at-or-before the cutoff and rewrite the URL; if no +pre-cutoff snapshot exists, we suppress the dashboard entirely rather +than fall back to live. + v1 — flagged for iteration after first benchmark run. """ from __future__ import annotations +import logging import uuid from datetime import datetime, timezone from typing import List @@ -16,14 +26,23 @@ extract_domain, normalize_url, ) +from bioscancast.stages.search_stage.wayback import closest_snapshot_before + +logger = logging.getLogger(__name__) def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: """Generate synthetic SearchResult entries for known pathogen dashboards. - If ``question.pathogen`` (lowercased) matches a key in DASHBOARD_LOOKUP, - returns a SearchResult for each URL with rank=0 and - retrieval_reason="dashboard_lookup". Returns empty list if no match. + Live mode: returns one SearchResult per URL with rank=0 and + retrieval_reason="dashboard_lookup". + + Historical-replay mode (``question.as_of_date`` is not None): for each + URL, looks up the closest Wayback snapshot at-or-before the cutoff and + emits a SearchResult pointing at the snapshot. Dashboards with no + pre-cutoff snapshot are suppressed entirely (NOT fallen-back to live) + because live dashboards return today's counts and would silently + contaminate the benchmark. """ if not question.pathogen: return [] @@ -33,11 +52,32 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: if not urls: return [] + as_of = question.as_of_date results: list[SearchResult] = [] now = datetime.now(timezone.utc) for url in urls: - domain = extract_domain(url) + if as_of is not None: + snapshot = closest_snapshot_before(url, as_of) + if snapshot is None: + logger.info( + "Suppressing dashboard %s — no Wayback snapshot at-or-before %s", + url, as_of.isoformat(), + ) + continue + snapshot_dt, snapshot_url = snapshot + effective_url = snapshot_url + published_date: datetime | None = snapshot_dt + published_date_source = "wayback_snapshot" + # Keep ``domain`` as the original publisher for tier scoring; + # the URL itself points at archive.org for fetching. + domain = extract_domain(url) + else: + effective_url = url + published_date = None + published_date_source = None + domain = extract_domain(url) + tier_num, domain_score, source_tier = resolve_tier(domain) results.append( @@ -46,13 +86,14 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: question_id=question.id, query_id=f"dashboard_{question.id}", engine="dashboard", - url=url, - canonical_url=normalize_url(url), + url=effective_url, + canonical_url=normalize_url(effective_url), domain=domain, title=f"Dashboard: {domain}", snippet=f"Known {pathogen_key} monitoring dashboard", rank=0, retrieved_at=now, + published_date=published_date, is_official_domain=(tier_num == 1 and source_tier == "official"), source_tier=source_tier, domain_score=domain_score, @@ -60,6 +101,8 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: retrieval_reason="dashboard_lookup", contains_aggregator_forecast=is_aggregator_domain(domain), search_stage_score=0.0, # computed later by pipeline + published_date_source=published_date_source, + cutoff_applied=as_of, ) ) diff --git a/bioscancast/stages/search_stage/date_recovery.py b/bioscancast/stages/search_stage/date_recovery.py new file mode 100644 index 0000000..2c851c0 --- /dev/null +++ b/bioscancast/stages/search_stage/date_recovery.py @@ -0,0 +1,124 @@ +"""Recover a plausible publication date for a SearchResult whose backend +didn't supply one. + +Why this exists: in historical-replay mode the pipeline must drop any source +it cannot date, because undated pages are exactly where post-cutoff content +can hide (a page first published before the cutoff but rewritten afterwards +will still report no ``published_date`` from Tavily). Soft-allowing undated +results would silently defeat the benchmark; this module instead tries cheap +external signals before giving up. + +Recovery strategies, cheapest first: + +1. URL slug regex (``/2024/03/15/...`` and ``/2024-03-15/...``) — free, no + network call. Catches most news organisations. +2. ``Last-Modified`` header via HEAD request — off by default. Requires + passing a fetcher callable explicitly; the search stage does not normally + carry an HTTP client. Opt in only when you need the recall. +3. Wayback Machine first-seen — one CDX call per URL. Conservative: "first + archived" is an upper bound on first published, so a pre-cutoff first- + seen is sound evidence the page existed before the cutoff. + +Each function returns ``Optional[datetime]`` and never raises on network or +parse errors (it logs and returns ``None``). +""" + +from __future__ import annotations + +import logging +import re +from datetime import datetime, timezone +from typing import Callable, Optional + +from .wayback import first_seen as _wayback_first_seen + +logger = logging.getLogger(__name__) + +# Matches /YYYY/MM/DD/ and /YYYY/MM/ and /YYYY-MM-DD/ within a URL path. +_URL_DATE_PATTERNS = [ + re.compile(r"/(\d{4})/(\d{1,2})/(\d{1,2})(?:/|$|[?#])"), + re.compile(r"/(\d{4})-(\d{1,2})-(\d{1,2})(?:/|$|[?#])"), + re.compile(r"/(\d{4})/(\d{1,2})(?:/|$|[?#])"), +] + + +def date_from_url_slug(url: str) -> Optional[datetime]: + """Extract a date from common URL slug patterns. Returns midnight UTC of + the matched date, or None if no pattern matches or the date is invalid.""" + for pattern in _URL_DATE_PATTERNS: + m = pattern.search(url) + if not m: + continue + groups = m.groups() + try: + year = int(groups[0]) + month = int(groups[1]) + day = int(groups[2]) if len(groups) >= 3 else 1 + if year < 1990 or year > 2100: + continue # almost certainly not a date + return datetime(year, month, day, tzinfo=timezone.utc) + except (ValueError, IndexError): + continue + return None + + +def date_from_last_modified( + url: str, head_fetcher: Optional[Callable[[str], Optional[str]]] = None +) -> Optional[datetime]: + """Issue a HEAD request and parse the Last-Modified header. + + The caller must pass a ``head_fetcher`` callable that returns the + Last-Modified header string (or None). The search stage does not have a + built-in HTTP client, so this is dependency-injected to avoid an awkward + import of ``curl_cffi`` into the search-stage package. Off by default. + """ + if head_fetcher is None: + return None + try: + header = head_fetcher(url) + except Exception as exc: + logger.warning("HEAD request failed for %s: %s", url, exc) + return None + if not header: + return None + # RFC 7231 format: "Wed, 21 Oct 2015 07:28:00 GMT" + for fmt in ("%a, %d %b %Y %H:%M:%S %Z", "%a, %d %b %Y %H:%M:%S GMT"): + try: + dt = datetime.strptime(header, fmt) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + except ValueError: + continue + return None + + +def date_from_wayback_first_seen(url: str) -> Optional[datetime]: + """Earliest Wayback capture timestamp for ``url`` as an upper bound on + first publication. Returns None on lookup failure.""" + return _wayback_first_seen(url) + + +def recover_published_date( + url: str, + head_fetcher: Optional[Callable[[str], Optional[str]]] = None, + use_wayback: bool = True, +) -> tuple[Optional[datetime], Optional[str]]: + """Try each strategy in order. Returns (date, source_label) where the + label is one of ``"url_slug" | "last_modified" | "wayback_first_seen"`` + on success, or (None, None) when no strategy yielded a date. + """ + dt = date_from_url_slug(url) + if dt is not None: + return dt, "url_slug" + + dt = date_from_last_modified(url, head_fetcher=head_fetcher) + if dt is not None: + return dt, "last_modified" + + if use_wayback: + dt = date_from_wayback_first_seen(url) + if dt is not None: + return dt, "wayback_first_seen" + + return None, None diff --git a/bioscancast/stages/search_stage/pipeline.py b/bioscancast/stages/search_stage/pipeline.py index 7882ab0..3012858 100644 --- a/bioscancast/stages/search_stage/pipeline.py +++ b/bioscancast/stages/search_stage/pipeline.py @@ -8,7 +8,8 @@ import logging import uuid -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone +from email.utils import parsedate_to_datetime from typing import List, Optional from bioscancast.filtering.config import FILTER_CONFIG @@ -17,6 +18,7 @@ from bioscancast.stages.search_stage.backends.base import RawSearchResult, SearchBackend from bioscancast.stages.search_stage.cache import SearchCache from bioscancast.stages.search_stage.dashboard_lookup import lookup_dashboards +from bioscancast.stages.search_stage.date_recovery import recover_published_date from bioscancast.stages.search_stage.query_decomposition import SubQuery, decompose_question from bioscancast.stages.search_stage.tier_resolution import ( is_aggregator_domain, @@ -30,15 +32,52 @@ # File extensions that indicate non-content resources _NON_CONTENT_EXTENSIONS: set[str] = {".zip", ".exe", ".msi", ".dmg", ".tar", ".gz", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".mp4", ".mp3"} +# Default lookback window for historical-replay mode. Tavily's news endpoint +# requires both start_date and end_date to be set together (passing end_date +# alone is silently ignored — see ``backends/tavily_backend.py``). We synthesize +# a start_date 12 months before the cutoff: empirically (2026-05-20) this gives +# 20/20 native pre-cutoff hit rate on the resolved corpus without leaking past +# the cutoff. Tune via ``historical_lookback_days`` on the pipeline. +_DEFAULT_HISTORICAL_LOOKBACK_DAYS = 365 -def _compute_freshness(published_date: Optional[datetime]) -> float: + +def _should_use_wayback_for_recovery(r: SearchResult) -> bool: + """Selective gate for the Wayback first-seen leg of the date-recovery chain. + + Wayback CDX is rate-limited (~60 req/min server-side) and even with + proactive throttling each call costs us a few seconds. For undated + results that would be dropped on quality grounds anyway — aggregators + and unknown-tier domains — there is no recall benefit to paying that + cost. The URL-slug regex and Last-Modified strategies still run; only + the Wayback leg is gated. + """ + domain = extract_domain(r.url) + if is_aggregator_domain(domain): + logger.debug("Date recovery: skipping Wayback for aggregator %s", domain) + return False + if (r.source_tier or "").lower() == "unknown": + logger.debug("Date recovery: skipping Wayback for unknown-tier %s", domain) + return False + return True + + +def _compute_freshness( + published_date: Optional[datetime], + *, + reference_date: Optional[datetime] = None, +) -> float: """Compute freshness score from published_date. - Returns 0.5 (neutral) when no date is available, per spec. + Returns 0.5 (neutral) when no date is available, per spec. ``reference_date`` + is the "now" against which age is measured; in historical-replay mode the + pipeline passes ``question.as_of_date`` so freshness is judged from the + human forecaster's vantage point. Defaults to wall-clock ``now`` for + live mode. """ if published_date is None: return 0.5 - days_old = (datetime.now(timezone.utc) - published_date).days + ref = reference_date or datetime.now(timezone.utc) + days_old = (ref - published_date).days if days_old < 0: return 1.0 return max(0.0, min(1.0, 1.0 - (days_old / 365.0))) @@ -52,7 +91,14 @@ def _compute_search_stage_score(domain_score: float, freshness_score: float, ran def _parse_published_date(date_str: Optional[str]) -> Optional[datetime]: - """Best-effort parse of backend-provided published_date strings.""" + """Best-effort parse of backend-provided published_date strings. + + Tavily inconsistently returns either ISO-8601 (``2025-02-17`` or + ``2025-02-17T13:00:00+00:00``) or RFC 2822 (``Tue, 19 May 2026 13:00:00 + GMT``) depending on the search topic, so we try both. Returning None + here is expensive in historical mode (it triggers the date-recovery + chain), so it matters that we cover the formats Tavily actually emits. + """ if not date_str: return None for fmt in ("%Y-%m-%dT%H:%M:%S%z", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"): @@ -63,7 +109,16 @@ def _parse_published_date(date_str: Optional[str]) -> Optional[datetime]: return dt except ValueError: continue - return None + # RFC 2822 fallback — what Tavily's news topic actually returns. + try: + dt = parsedate_to_datetime(date_str) + if dt is None: + return None + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + except (TypeError, ValueError): + return None def _is_non_content_url(url: str) -> bool: @@ -73,7 +128,25 @@ def _is_non_content_url(url: str) -> bool: class SearchStagePipeline: - """Orchestrates the full search stage: decompose → search → score → deduplicate.""" + """Orchestrates the full search stage: decompose → search → score → deduplicate. + + Historical-replay mode is activated implicitly by ``question.as_of_date``. + When that field is non-None: + + * the Tavily/CSE backend receives ``end_date=as_of_date`` and is asked to + restrict results to pages dated on or before the cutoff, + * the cache key incorporates the cutoff so replay runs don't see each + other's results, + * freshness scoring uses the cutoff as "now" rather than wall-clock time, + * a post-retrieval filter drops anything dated after the cutoff and any + undated result whose date can't be recovered from a cheap fallback, + * the dashboard injection rewrites URLs to closest Wayback snapshots, + suppressing dashboards with no pre-cutoff snapshot entirely, + * (opt-in) the LLM decomposition prompt is asked to roleplay the cutoff. + + The ``historical_roleplay`` constructor flag controls only the last item; + everything else is implicit on ``as_of_date``. + """ def __init__( self, @@ -83,6 +156,11 @@ def __init__( results_per_query: int = 10, total_cap: int = 60, backend_name: str = "tavily", + historical_roleplay: bool = False, + min_post_filter_results: int = 10, + top_up_results_per_query: int = 50, + max_top_up_rounds: int = 1, + historical_lookback_days: int = _DEFAULT_HISTORICAL_LOOKBACK_DAYS, ) -> None: self._backend = search_backend self._llm = llm_client @@ -90,70 +168,221 @@ def __init__( self._results_per_query = results_per_query self._total_cap = total_cap self._backend_name = backend_name + self._historical_roleplay = historical_roleplay + # Top-up parameters apply only in historical-replay mode. In live + # mode the initial pass is always considered sufficient. + self._min_post_filter_results = min_post_filter_results + self._top_up_results_per_query = top_up_results_per_query + self._max_top_up_rounds = max_top_up_rounds + # In historical-replay mode the backend receives end_date=as_of_date + # and start_date=as_of_date-lookback. Tavily requires the pair; see + # the module-level note on ``_DEFAULT_HISTORICAL_LOOKBACK_DAYS``. + self._historical_lookback_days = historical_lookback_days def run(self, question: ForecastQuestion) -> List[SearchResult]: """Execute the full search stage pipeline.""" + as_of = question.as_of_date + # 1. Decompose question into sub-queries - sub_queries = decompose_question(question, self._llm) + sub_queries = decompose_question( + question, + self._llm, + historical_roleplay=self._historical_roleplay, + ) logger.info("Decomposed into %d sub-queries", len(sub_queries)) # 2. Inject dashboard lookups dashboard_results = lookup_dashboards(question) logger.info("Dashboard lookup produced %d results", len(dashboard_results)) - # 3. Execute searches per sub-query + # 3. Execute initial search round all_results: list[SearchResult] = list(dashboard_results) - for sq in sub_queries: - raw_results = self._execute_search(sq.text) - for rank_offset, raw in enumerate(raw_results): - result = self._convert(raw, sq, question.id, rank_offset + 1) - all_results.append(result) - - if len(all_results) >= self._total_cap: - logger.info("Hit total cap of %d results before all sub-queries", self._total_cap) - break - - # 4. Deduplicate - deduped = self._deduplicate(all_results) - - # 5. Hard exclusions - filtered = self._apply_exclusions(deduped) + seen_canonical: set[str] = set() + for r in dashboard_results: + if r.canonical_url: + seen_canonical.add(r.canonical_url) + + all_results, seen_canonical = self._search_round( + sub_queries, + question, + as_of, + max_results=self._results_per_query, + collected=all_results, + seen_canonical=seen_canonical, + stop_cap=self._total_cap, + ) - # 6. Compute search_stage_score + # 4-6. Dedup → exclusions → cutoff filter + filtered = self._dedup_exclude_cutoff(all_results, as_of) + + # 6b. Top-up: in historical mode only, if we're below the survivor + # threshold, run additional rounds with a larger results_per_query + # to fish for more in-window content. With the start_date+end_date + # pair now forwarded to Tavily (see backends/tavily_backend.py), + # the candidate pool is already date-filtered upstream; top-up + # mostly compensates for results dropped by deduplication and the + # blocked-domain filter. + if as_of is not None: + rounds_done = 0 + while ( + rounds_done < self._max_top_up_rounds + and len(filtered) < self._min_post_filter_results + ): + rounds_done += 1 + logger.info( + "Historical top-up round %d: have %d survivors, want >= %d", + rounds_done, len(filtered), self._min_post_filter_results, + ) + all_results, seen_canonical = self._search_round( + sub_queries, + question, + as_of, + max_results=self._top_up_results_per_query, + collected=all_results, + seen_canonical=seen_canonical, + # Allow many more candidates than the final cap because + # most will be dropped by the cutoff filter. + stop_cap=self._total_cap * 10, + ) + filtered = self._dedup_exclude_cutoff(all_results, as_of) + + if len(filtered) < self._min_post_filter_results: + logger.warning( + "Historical top-up exhausted: %d survivors after %d round(s) " + "(target was %d). Returning what we have.", + len(filtered), rounds_done, self._min_post_filter_results, + ) + + # 7. Compute search_stage_score (freshness measured from cutoff in + # historical mode, wall-clock in live mode) for r in filtered: + r.freshness_score = _compute_freshness( + r.published_date, reference_date=as_of + ) r.search_stage_score = _compute_search_stage_score( r.domain_score, r.freshness_score, r.rank ) - # 7. Sort and cap + # 8. Sort and cap filtered.sort(key=lambda r: r.search_stage_score, reverse=True) result = filtered[: self._total_cap] logger.info("Search stage returning %d results", len(result)) return result - def _execute_search(self, query: str) -> List[RawSearchResult]: + def _search_round( + self, + sub_queries: List[SubQuery], + question: ForecastQuestion, + as_of: Optional[datetime], + *, + max_results: int, + collected: list[SearchResult], + seen_canonical: set[str], + stop_cap: int, + ) -> tuple[list[SearchResult], set[str]]: + """Issue each sub-query and append converted SearchResults to + ``collected``, skipping any URL already in ``seen_canonical``. + Returns the updated list and seen-set. Stops early when the + collected list reaches ``stop_cap``. + """ + for sq in sub_queries: + query_text = self._apply_year_hint(sq.text, as_of) + raw_results = self._execute_search( + query_text, as_of_date=as_of, max_results=max_results + ) + for rank_offset, raw in enumerate(raw_results): + canonical = normalize_url(raw.url) + if canonical and canonical in seen_canonical: + continue + result = self._convert(raw, sq, question.id, rank_offset + 1, as_of) + collected.append(result) + if canonical: + seen_canonical.add(canonical) + if len(collected) >= stop_cap: + logger.info( + "Stopping search round at %d collected results (cap=%d)", + len(collected), stop_cap, + ) + break + return collected, seen_canonical + + def _dedup_exclude_cutoff( + self, results: list[SearchResult], as_of: Optional[datetime] + ) -> list[SearchResult]: + """Run dedup → hard exclusions → cutoff filter (historical mode only).""" + deduped = self._deduplicate(results) + filtered = self._apply_exclusions(deduped) + if as_of is not None: + filtered = self._apply_cutoff_filter(filtered, as_of) + return filtered + + @staticmethod + def _apply_year_hint(query: str, as_of: Optional[datetime]) -> str: + """In historical mode, append the cutoff year to the query so the + search backend's lexical match biases toward dated content. The + start_date+end_date pair forwarded to Tavily already filters by + publication date, but the year hint reinforces topical relevance + within the window (Tavily's in-window ranking can still surface + irrelevant dated-correct results on cold or sparse queries). No-op + in live mode.""" + if as_of is None: + return query + year = as_of.year + # Avoid double-hinting if the LLM already put the year in. + if str(year) in query: + return query + return f"{query} {year}" + + def _execute_search( + self, + query: str, + as_of_date: Optional[datetime] = None, + max_results: Optional[int] = None, + ) -> List[RawSearchResult]: # TODO: multilingual support + # In historical-replay mode we pass BOTH start_date and end_date. + # Tavily silently ignores end_date when start_date is missing + # (verified 2026-05-20, specs/tavily-investigation-findings.md). + end_date_str: Optional[str] = None + start_date_str: Optional[str] = None + if as_of_date is not None: + end_date_str = as_of_date.strftime("%Y-%m-%d") + start_date_str = ( + as_of_date - timedelta(days=self._historical_lookback_days) + ).strftime("%Y-%m-%d") + effective_max = max_results if max_results is not None else self._results_per_query if self._cache: - cached = self._cache.get(self._backend_name, query) + cached = self._cache.get(self._backend_name, query, as_of_date=as_of_date) if cached is not None: logger.debug("Cache hit for query: %s", query) return cached - results = self._backend.search(query, max_results=self._results_per_query) + results = self._backend.search( + query, + max_results=effective_max, + end_date=end_date_str, + start_date=start_date_str, + ) if self._cache: - self._cache.put(self._backend_name, query, results) + self._cache.put(self._backend_name, query, results, as_of_date=as_of_date) return results def _convert( - self, raw: RawSearchResult, sub_query: SubQuery, question_id: str, rank: int + self, + raw: RawSearchResult, + sub_query: SubQuery, + question_id: str, + rank: int, + as_of_date: Optional[datetime] = None, ) -> SearchResult: domain = extract_domain(raw.url) canonical = normalize_url(raw.url) tier_num, domain_score, source_tier = resolve_tier(domain) published = _parse_published_date(raw.published_date) - freshness = _compute_freshness(published) + freshness = _compute_freshness(published, reference_date=as_of_date) + published_date_source = "backend" if published is not None else None return SearchResult( id=uuid.uuid4().hex, @@ -177,6 +406,8 @@ def _convert( # kept in results so downstream analysis can measure contamination effects. contains_aggregator_forecast=is_aggregator_domain(domain), search_stage_score=0.0, # computed after dedup + published_date_source=published_date_source, + cutoff_applied=as_of_date, ) def _deduplicate(self, results: List[SearchResult]) -> List[SearchResult]: @@ -219,6 +450,69 @@ def _apply_exclusions(self, results: List[SearchResult]) -> List[SearchResult]: kept.append(r) return kept + def _apply_cutoff_filter( + self, results: List[SearchResult], as_of: datetime + ) -> List[SearchResult]: + """Historical-replay mode: keep only results that demonstrably existed + before ``as_of``. Drop post-cutoff and undatable results. + + Wayback-snapshot dashboards already have ``published_date`` set to the + capture timestamp by ``dashboard_lookup``; this filter is therefore + idempotent on them. + """ + dropped_post_cutoff = 0 + dropped_undatable = 0 + recovered = 0 + wayback_skipped = 0 + kept: list[SearchResult] = [] + for r in results: + if r.published_date is not None: + if r.published_date > as_of: + dropped_post_cutoff += 1 + logger.debug( + "Cutoff filter: dropping post-cutoff %s (pub=%s, cutoff=%s)", + r.url, r.published_date.isoformat(), as_of.isoformat(), + ) + continue + kept.append(r) + continue + + # Undated — try the recovery chain. Skip the Wayback first-seen + # leg for aggregator domains and unknown-tier sources: those + # results would be dropped on quality grounds anyway, and the + # CDX call (even with throttling) costs us several seconds each. + use_wayback = _should_use_wayback_for_recovery(r) + if not use_wayback: + wayback_skipped += 1 + recovered_date, source = recover_published_date( + r.url, use_wayback=use_wayback + ) + if recovered_date is None: + dropped_undatable += 1 + logger.debug( + "Cutoff filter: dropping %s (no_date_available)", r.url + ) + continue + if recovered_date > as_of: + dropped_post_cutoff += 1 + logger.debug( + "Cutoff filter: recovered date %s > cutoff for %s", + recovered_date.isoformat(), r.url, + ) + continue + r.published_date = recovered_date + r.published_date_source = source + recovered += 1 + kept.append(r) + + logger.info( + "Cutoff filter: kept=%d, recovered=%d, dropped_post_cutoff=%d, " + "dropped_undatable=%d, wayback_skipped=%d (cutoff=%s)", + len(kept), recovered, dropped_post_cutoff, dropped_undatable, + wayback_skipped, as_of.isoformat(), + ) + return kept + def run_search_stage( question: ForecastQuestion, @@ -226,6 +520,7 @@ def run_search_stage( llm_client: LLMClient, cache: Optional[SearchCache] = None, backend_name: str = "tavily", + historical_roleplay: bool = False, ) -> List[SearchResult]: """Convenience function to run the search stage pipeline.""" pipeline = SearchStagePipeline( @@ -233,5 +528,6 @@ def run_search_stage( llm_client=llm_client, cache=cache, backend_name=backend_name, + historical_roleplay=historical_roleplay, ) return pipeline.run(question) diff --git a/bioscancast/stages/search_stage/query_decomposition.py b/bioscancast/stages/search_stage/query_decomposition.py index 7456cf0..a1b14d9 100644 --- a/bioscancast/stages/search_stage/query_decomposition.py +++ b/bioscancast/stages/search_stage/query_decomposition.py @@ -88,16 +88,29 @@ def classify_question_type(question: ForecastQuestion, llm_client: LLMClient) -> return "unknown" -def _build_decomposition_prompt(question: ForecastQuestion, question_type: str) -> str: +def _build_decomposition_prompt( + question: ForecastQuestion, + question_type: str, + historical_roleplay: bool = False, +) -> str: axes = AXES_BY_TYPE.get(question_type, list(VALID_AXES)) + task_lines = [ + "Decompose this biosecurity forecast question into 5-8 search-engine-optimised " + "sub-queries. Each sub-query should be 2-8 words and target a specific information " + "axis. Return strict JSON: {\"sub_queries\": [{\"text\": \"...\", \"axis\": \"...\"}]}. " + "No prose." + ] + if historical_roleplay and question.as_of_date is not None: + task_lines.append( + "IMPORTANT: Generate sub-queries as if today were " + f"{question.as_of_date.date().isoformat()}. Do not assume knowledge " + "of events, named entities, or facts that you only learned about " + "after that date. Phrase queries in terms a forecaster on that " + "date would have used." + ) return json.dumps( { - "task": ( - "Decompose this biosecurity forecast question into 5-8 search-engine-optimised " - "sub-queries. Each sub-query should be 2-8 words and target a specific information " - "axis. Return strict JSON: {\"sub_queries\": [{\"text\": \"...\", \"axis\": \"...\"}]}. " - "No prose." - ), + "task": " ".join(task_lines), "question": question.text, "pathogen": question.pathogen, "region": question.region, @@ -158,14 +171,25 @@ def _fallback_subqueries(question: ForecastQuestion) -> List[SubQuery]: def decompose_question( - question: ForecastQuestion, llm_client: LLMClient + question: ForecastQuestion, + llm_client: LLMClient, + *, + historical_roleplay: bool = False, ) -> List[SubQuery]: """Decompose a forecast question into sub-queries using an LLM. Falls back to simple keyword-based sub-queries if the LLM fails. + + ``historical_roleplay`` is an opt-in benchmark-only flag. When True AND + ``question.as_of_date`` is set, the prompt is extended with an instruction + asking the LLM to query as if today were the cutoff date. This is gated + behind its own flag because prompt-level roleplay can have hard-to-predict + effects on query quality. """ question_type = classify_question_type(question, llm_client) - prompt = _build_decomposition_prompt(question, question_type) + prompt = _build_decomposition_prompt( + question, question_type, historical_roleplay=historical_roleplay + ) try: result = llm_client.generate_json(prompt) diff --git a/bioscancast/stages/search_stage/wayback.py b/bioscancast/stages/search_stage/wayback.py new file mode 100644 index 0000000..b17bc88 --- /dev/null +++ b/bioscancast/stages/search_stage/wayback.py @@ -0,0 +1,226 @@ +"""Tiny Wayback Machine CDX client used by historical-replay mode. + +Two callers need this: + +* the search-stage dashboard rewrite (closest snapshot at-or-before cutoff), +* the extraction-stage fetcher (same lookup, then fetch the snapshot bytes), +* and the date-recovery chain in the search stage (first-seen capture). + +The implementation deliberately uses stdlib ``urllib`` rather than ``curl_cffi``: +the Wayback CDX endpoint returns a small JSON document, is not protected by +Cloudflare/JA3 filters, and adding ``curl_cffi`` as a dependency of the +search stage would broaden the dependency surface unnecessarily. + +All functions return ``None`` on any network/parse error and log at WARNING. +Callers must tolerate ``None`` — Wayback is best-effort. +""" + +from __future__ import annotations + +import json +import logging +import os +import socket +import threading +import time +import urllib.error +import urllib.parse +import urllib.request +from datetime import datetime, timezone +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + +CDX_ENDPOINT = "https://web.archive.org/cdx/search/cdx" +SNAPSHOT_TEMPLATE = "https://web.archive.org/web/{timestamp}/{url}" +# The ``id_`` modifier on the timestamp returns the raw original bytes +# (no Wayback toolbar / no rewriting). Use this when we want to feed the +# response straight into our HTML/PDF parsers. +RAW_SNAPSHOT_TEMPLATE = "https://web.archive.org/web/{timestamp}id_/{url}" + +_REQUEST_TIMEOUT_SECONDS = 30 + +# Retry schedule for the Wayback CDX endpoint. The endpoint frequently +# returns HTTP 503 or times out under load — empirically a single +# benchmark run can trigger dozens of consecutive failures. We trade +# wall-clock time for completeness because historical-replay benchmarks +# are not latency-sensitive. Pre-attempt delays in seconds; the i-th +# entry is the wait BEFORE attempt i (so the first attempt fires +# immediately). Override at module level in tests to keep them fast. +RETRY_BACKOFF_SECONDS: tuple[float, ...] = (0, 10, 30, 90, 240) + +# Recoverable HTTP status codes that warrant a retry. +_RECOVERABLE_STATUSES = {429, 500, 502, 503, 504} + +# Minimum interval between successive outbound CDX calls. Internet Archive +# rate-limits CDX at ~60 req/min server-side; the widely-used edgi-govdata +# Python client paces at ~0.8 req/s (1.25 s) by default. We sit at 2.0 s +# (30 req/min) — comfortably under the server cap with headroom for bursts, +# but ~2x throughput vs the initial conservative 4 s setting once we +# confirmed the throttle eliminates 429s in practice. Override via env var +# ``BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS`` for ad-hoc tuning. +_DEFAULT_MIN_INTERVAL_SECONDS = 2.0 +_MIN_INTERVAL_ENV_VAR = "BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS" +_throttle_lock = threading.Lock() +_last_call_monotonic: float = 0.0 + + +def _min_interval_seconds() -> float: + raw = os.environ.get(_MIN_INTERVAL_ENV_VAR) + if raw is None: + return _DEFAULT_MIN_INTERVAL_SECONDS + try: + return float(raw) + except ValueError: + logger.warning( + "Invalid %s=%r; using default %.1fs", + _MIN_INTERVAL_ENV_VAR, raw, _DEFAULT_MIN_INTERVAL_SECONDS, + ) + return _DEFAULT_MIN_INTERVAL_SECONDS + + +def _sleep(seconds: float) -> None: + """Indirection so tests can monkeypatch a no-op sleep.""" + if seconds > 0: + time.sleep(seconds) + + +def _throttle() -> None: + """Block until the configured min interval since the last CDX call has elapsed.""" + global _last_call_monotonic + min_interval = _min_interval_seconds() + with _throttle_lock: + elapsed = time.monotonic() - _last_call_monotonic + wait = min_interval - elapsed + if wait > 0: + _sleep(wait) + _last_call_monotonic = time.monotonic() + + +def _cdx_query(params: dict) -> Optional[list]: + """POST-free GET against the CDX endpoint. Returns the parsed JSON list, + or None on any failure. Retries on HTTP 503/429/5xx and read timeouts + according to ``RETRY_BACKOFF_SECONDS``.""" + query = urllib.parse.urlencode(params) + full_url = f"{CDX_ENDPOINT}?{query}" + + body: Optional[str] = None + for attempt, pre_delay in enumerate(RETRY_BACKOFF_SECONDS, start=1): + if pre_delay: + logger.info( + "Wayback CDX backoff %.0fs before attempt %d/%d", + pre_delay, attempt, len(RETRY_BACKOFF_SECONDS), + ) + _sleep(pre_delay) + _throttle() + try: + req = urllib.request.Request( + full_url, headers={"User-Agent": "BioScanCast/replay (+wayback-cdx)"} + ) + with urllib.request.urlopen(req, timeout=_REQUEST_TIMEOUT_SECONDS) as resp: + body = resp.read().decode("utf-8", errors="replace") + break # success + except urllib.error.HTTPError as exc: + if exc.code in _RECOVERABLE_STATUSES and attempt < len(RETRY_BACKOFF_SECONDS): + logger.info( + "Wayback CDX HTTP %d on attempt %d; retrying", + exc.code, attempt, + ) + continue + logger.warning( + "Wayback CDX gave up after %d attempt(s): HTTP %d for %s", + attempt, exc.code, full_url, + ) + return None + except (socket.timeout, TimeoutError, urllib.error.URLError) as exc: + # urllib.error.URLError wraps socket.timeout on read timeouts in + # some Python builds; check both. + is_timeout = isinstance(exc, (socket.timeout, TimeoutError)) or ( + isinstance(exc, urllib.error.URLError) + and isinstance(getattr(exc, "reason", None), (socket.timeout, TimeoutError)) + ) + if is_timeout and attempt < len(RETRY_BACKOFF_SECONDS): + logger.info( + "Wayback CDX timeout on attempt %d; retrying", attempt, + ) + continue + logger.warning( + "Wayback CDX gave up after %d attempt(s): %s for %s", + attempt, exc, full_url, + ) + return None + except Exception as exc: + logger.warning( + "Wayback CDX non-recoverable error: %s for %s", exc, full_url + ) + return None + + if body is None: + return None + + if not body.strip(): + return [] + try: + data = json.loads(body) + except json.JSONDecodeError: + logger.warning("Wayback CDX returned non-JSON body for %s", full_url) + return None + + # First row is the header. Drop it. + if isinstance(data, list) and data and isinstance(data[0], list): + return data[1:] + return [] + + +def _parse_cdx_timestamp(ts: str) -> Optional[datetime]: + """CDX timestamps are YYYYMMDDhhmmss in UTC.""" + try: + return datetime.strptime(ts, "%Y%m%d%H%M%S").replace(tzinfo=timezone.utc) + except ValueError: + return None + + +def closest_snapshot_before( + url: str, as_of: datetime +) -> Optional[Tuple[datetime, str]]: + """Return (snapshot_datetime, raw_snapshot_url) for the latest Wayback + capture of ``url`` whose timestamp is ``<= as_of``. Returns None when no + suitable snapshot exists or the lookup fails. + + The returned URL uses the ``id_`` modifier so callers get unwrapped + original content (no Wayback chrome). + """ + to_param = as_of.astimezone(timezone.utc).strftime("%Y%m%d%H%M%S") + rows = _cdx_query( + { + "url": url, + "to": to_param, + "limit": "-1", # most recent matching row + "output": "json", + "filter": "statuscode:200", + } + ) + if not rows: + return None + timestamp = rows[0][1] # column 1 is the capture timestamp + parsed = _parse_cdx_timestamp(timestamp) + if parsed is None: + return None + snapshot_url = RAW_SNAPSHOT_TEMPLATE.format(timestamp=timestamp, url=url) + return parsed, snapshot_url + + +def first_seen(url: str) -> Optional[datetime]: + """Return the earliest Wayback capture timestamp for ``url``, or None.""" + rows = _cdx_query( + { + "url": url, + "limit": "1", + "output": "json", + "filter": "statuscode:200", + "sort": "ascending", + } + ) + if not rows: + return None + return _parse_cdx_timestamp(rows[0][1]) diff --git a/bioscancast/tests/test_contamination_metrics.py b/bioscancast/tests/test_contamination_metrics.py new file mode 100644 index 0000000..c48829f --- /dev/null +++ b/bioscancast/tests/test_contamination_metrics.py @@ -0,0 +1,132 @@ +from datetime import datetime, timezone + +from bioscancast.filtering.models import ForecastQuestion, SearchResult +from bioscancast.stages.eval_stage.contamination import ( + BaselineForecast, + ContaminationCounts, + filter_caught_contamination_rate, + retrieval_free_baseline_forecast, +) + + +def _result(pub: datetime | None) -> SearchResult: + return SearchResult( + id="x", + question_id="Q", + query_id="q1", + engine="fake", + url="https://example.com/x", + canonical_url="https://example.com/x", + domain="example.com", + title="T", + snippet="S", + rank=1, + retrieved_at=datetime.now(timezone.utc), + published_date=pub, + ) + + +class TestFilterCaughtContaminationRate: + def test_clean_run_is_zero(self): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + results = [ + _result(datetime(2024, 1, 1, tzinfo=timezone.utc)), + _result(datetime(2024, 5, 31, tzinfo=timezone.utc)), + ] + counts = filter_caught_contamination_rate(results, cutoff) + assert counts.post_cutoff_in_final == 0 + assert counts.filter_caught_rate == 0.0 + assert counts.pre_cutoff_in_final == 2 + + def test_some_leak_through(self): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + results = [ + _result(datetime(2024, 5, 1, tzinfo=timezone.utc)), + _result(datetime(2024, 8, 1, tzinfo=timezone.utc)), # post + _result(datetime(2024, 9, 1, tzinfo=timezone.utc)), # post + _result(None), # undated + ] + counts = filter_caught_contamination_rate(results, cutoff) + assert counts.post_cutoff_in_final == 2 + assert counts.undated_in_final == 1 + assert counts.pre_cutoff_in_final == 1 + assert counts.filter_caught_rate == 0.5 + + def test_empty_list(self): + counts = filter_caught_contamination_rate( + [], datetime(2024, 1, 1, tzinfo=timezone.utc) + ) + assert counts.filter_caught_rate == 0.0 + assert counts.total == 0 + + +class TestRetrievalFreeBaselineForecast: + def test_well_formed_response(self): + question = ForecastQuestion( + id="Q1", + text="Will X happen?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + + class GoodLLM: + def generate_json(self, prompt): + return {"probabilities": [0.7, 0.3], "rationale": "guess"} + + out = retrieval_free_baseline_forecast( + question, options=["yes", "no"], llm_client=GoodLLM() + ) + assert isinstance(out, BaselineForecast) + assert abs(sum(out.probabilities) - 1.0) < 1e-9 + assert out.probabilities[0] > out.probabilities[1] + assert out.rationale == "guess" + + def test_renormalises_unnormalised_probabilities(self): + question = ForecastQuestion( + id="Q1", + text="?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + class UnnormLLM: + def generate_json(self, prompt): + return {"probabilities": [2.0, 6.0], "rationale": ""} + + out = retrieval_free_baseline_forecast( + question, options=["a", "b"], llm_client=UnnormLLM() + ) + assert abs(sum(out.probabilities) - 1.0) < 1e-9 + assert abs(out.probabilities[0] - 0.25) < 1e-9 + + def test_malformed_response_uniform_fallback(self): + question = ForecastQuestion( + id="Q1", + text="?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + class BadLLM: + def generate_json(self, prompt): + return {"probabilities": "not a list"} + + out = retrieval_free_baseline_forecast( + question, options=["a", "b", "c"], llm_client=BadLLM() + ) + assert out.probabilities == [1 / 3, 1 / 3, 1 / 3] + assert "fallback" in (out.rationale or "") + + def test_llm_exception_uniform_fallback(self): + question = ForecastQuestion( + id="Q1", + text="?", + created_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + ) + + class ExplodingLLM: + def generate_json(self, prompt): + raise RuntimeError("oops") + + out = retrieval_free_baseline_forecast( + question, options=["a", "b"], llm_client=ExplodingLLM() + ) + assert out.probabilities == [0.5, 0.5] diff --git a/bioscancast/tests/test_cutoff_filtering.py b/bioscancast/tests/test_cutoff_filtering.py new file mode 100644 index 0000000..fd30137 --- /dev/null +++ b/bioscancast/tests/test_cutoff_filtering.py @@ -0,0 +1,410 @@ +"""End-to-end-ish tests for historical-replay mode in SearchStagePipeline. + +Uses the same FakeLLMClient/FakeSearchBackend pattern as test_search_pipeline.py +to keep the test layer hand-rolled and dependency-free. +""" + +from datetime import datetime, timezone +from typing import List +from unittest.mock import patch + +from bioscancast.filtering.models import ForecastQuestion, SearchResult +from bioscancast.stages.search_stage.backends.base import RawSearchResult +from bioscancast.stages.search_stage.pipeline import ( + SearchStagePipeline, + _parse_published_date, + _should_use_wayback_for_recovery, +) + + +class TestParsePublishedDate: + def test_iso_with_offset(self): + assert _parse_published_date("2025-02-17T13:00:00+00:00") == datetime( + 2025, 2, 17, 13, 0, 0, tzinfo=timezone.utc + ) + + def test_iso_date_only(self): + assert _parse_published_date("2025-02-17") == datetime( + 2025, 2, 17, tzinfo=timezone.utc + ) + + def test_rfc2822_with_zone(self): + # The format Tavily's news topic actually returns. + result = _parse_published_date("Tue, 19 May 2026 13:00:00 GMT") + assert result is not None + assert result.year == 2026 + assert result.month == 5 + assert result.day == 19 + assert result.tzinfo is not None + + def test_rfc2822_with_offset(self): + result = _parse_published_date("Tue, 19 May 2026 13:00:00 +0000") + assert result is not None + assert result.day == 19 + + def test_none_and_empty(self): + assert _parse_published_date(None) is None + assert _parse_published_date("") is None + + def test_garbage_returns_none(self): + assert _parse_published_date("not a date") is None + + +class _FakeLLM: + def __init__(self): + self._calls = 0 + + def generate_json(self, prompt: str) -> dict: + self._calls += 1 + if self._calls == 1: + return {"question_type": "outbreak_count"} + return { + "sub_queries": [ + {"text": "H5N1 cases 2024", "axis": "latest_data"}, + {"text": "avian flu trend", "axis": "trend"}, + {"text": "bird flu policy", "axis": "policy"}, + ] + } + + +class _FakeBackend: + def __init__(self, results: List[RawSearchResult]): + self._results = results + self.end_dates_seen: list = [] + self.start_dates_seen: list = [] + + def search(self, query, max_results=10, end_date=None, start_date=None): + self.end_dates_seen.append(end_date) + self.start_dates_seen.append(start_date) + return list(self._results) + + +def _make_question(as_of: datetime | None) -> ForecastQuestion: + return ForecastQuestion( + id="Q-CUT", + text="Will H5N1 exceed 100 cases by end of 2024?", + created_at=datetime(2024, 6, 1, tzinfo=timezone.utc), + pathogen="nopathogen", # avoid dashboard injection in this test + as_of_date=as_of, + ) + + +def test_post_cutoff_results_are_dropped(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://news.example.com/a", + title="Pre-cutoff", + snippet="", + rank=1, + published_date="2024-05-15", + ), + RawSearchResult( + url="https://news.example.com/b", + title="Post-cutoff", + snippet="", + rank=2, + published_date="2024-08-15", + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + results = pipeline.run(_make_question(cutoff)) + urls = {r.url for r in results} + assert "https://news.example.com/a" in urls + assert "https://news.example.com/b" not in urls + + +def test_undated_dropped_when_recovery_fails(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://news.example.com/no-date", + title="Undated", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + with patch( + "bioscancast.stages.search_stage.pipeline.recover_published_date" + ) as mock_rec: + mock_rec.return_value = (None, None) + results = pipeline.run(_make_question(cutoff)) + assert not any(r.url == "https://news.example.com/no-date" for r in results) + + +def test_undated_kept_when_recovery_succeeds_before_cutoff(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://news.example.com/2024/03/15/article", + title="Slug-dated", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + results = pipeline.run(_make_question(cutoff)) + matching = [r for r in results if "2024/03/15" in r.url] + assert len(matching) == 1 + assert matching[0].published_date_source == "url_slug" + + +def test_end_date_forwarded_to_backend(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + pipeline.run(_make_question(cutoff)) + assert all(d == "2024-06-01" for d in backend.end_dates_seen if d is not None) + assert any(d == "2024-06-01" for d in backend.end_dates_seen) + + +def test_live_mode_unchanged(): + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date=None, + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + results = pipeline.run(_make_question(as_of=None)) + # Undated result MUST be kept in live mode (the cutoff filter is off) + assert any(r.url == "https://news.example.com/x" for r in results) + # And backend received end_date=None AND start_date=None — Tavily ignores + # end_date when start_date is missing, so the pipeline must keep them + # both unset in live mode. + assert all(d is None for d in backend.end_dates_seen) + assert all(d is None for d in backend.start_dates_seen) + + +def test_historical_mode_forwards_start_and_end_date_pair(): + """Tavily honors end_date only when start_date is also set. The pipeline + must synthesize start_date = as_of - historical_lookback_days and pass + both to the backend on every search call.""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + historical_lookback_days=365, + ) + pipeline.run(_make_question(cutoff)) + # Every search call in historical mode must carry BOTH bounds. + paired = [ + (s, e) + for s, e in zip(backend.start_dates_seen, backend.end_dates_seen) + if s is not None or e is not None + ] + assert paired, "expected at least one date-bounded search in historical mode" + for start, end in paired: + assert start is not None and end is not None, ( + "Tavily ignores end_date alone — pipeline must pass the pair" + ) + assert end == "2024-06-01" + assert start == "2023-06-02" # 365 days before 2024-06-01 + + +def test_historical_lookback_days_is_configurable(): + """Override the default 365-day lookback via the pipeline constructor.""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + historical_lookback_days=30, + ) + pipeline.run(_make_question(cutoff)) + starts = [s for s in backend.start_dates_seen if s is not None] + assert starts and all(s == "2024-05-02" for s in starts) # 30 days before + + +def _make_search_result(url: str, source_tier: str = "trusted_media") -> SearchResult: + return SearchResult( + id="r1", + question_id="q1", + query_id="sq1", + engine="fake", + url=url, + canonical_url=None, + domain="", + title="t", + snippet="", + rank=1, + retrieved_at=datetime(2024, 1, 1, tzinfo=timezone.utc), + source_tier=source_tier, + ) + + +class TestSelectiveRecoveryGate: + """The Wayback-leg gate on the date-recovery chain.""" + + def test_official_tier_uses_wayback(self): + r = _make_search_result( + "https://www.cdc.gov/bird-flu/situation-summary/", source_tier="official" + ) + assert _should_use_wayback_for_recovery(r) is True + + def test_academic_tier_uses_wayback(self): + r = _make_search_result( + "https://www.nature.com/articles/xyz", source_tier="academic" + ) + assert _should_use_wayback_for_recovery(r) is True + + def test_unknown_tier_skips_wayback(self): + r = _make_search_result( + "https://obscure-site.example/article", source_tier="unknown" + ) + assert _should_use_wayback_for_recovery(r) is False + + def test_aggregator_domain_skips_wayback(self): + # metaculus.com is in AGGREGATOR_DOMAINS regardless of tier label. + r = _make_search_result( + "https://www.metaculus.com/questions/12345/", + source_tier="trusted_media", + ) + assert _should_use_wayback_for_recovery(r) is False + + +def test_aggregator_undated_recovery_skips_wayback(): + """End-to-end: an undated aggregator result with no slug date routes to + recover_published_date with use_wayback=False, so the Wayback leg never + fires for it.""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://www.metaculus.com/questions/abc", # known aggregator + title="Aggregator forecast", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + with patch( + "bioscancast.stages.search_stage.pipeline.recover_published_date", + return_value=(None, None), + ) as mock_rec: + pipeline.run(_make_question(cutoff)) + # The recovery function was called, but with use_wayback=False. + assert mock_rec.called + # At least one of the calls was for the aggregator URL with use_wayback=False. + aggregator_calls = [ + c for c in mock_rec.call_args_list + if c.args and "metaculus.com" in c.args[0] + ] + assert aggregator_calls + for call in aggregator_calls: + assert call.kwargs.get("use_wayback") is False + + +def test_official_undated_recovery_still_tries_wayback(): + """A tier-1 official domain with no slug date should still hit the + Wayback leg of recovery (i.e., use_wayback=True).""" + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend_results = [ + RawSearchResult( + url="https://www.cdc.gov/some/article", # tier 1 official + title="CDC article", + snippet="", + rank=1, + published_date=None, + ), + ] + pipeline = SearchStagePipeline( + search_backend=_FakeBackend(backend_results), + llm_client=_FakeLLM(), + backend_name="fake", + ) + with patch( + "bioscancast.stages.search_stage.pipeline.recover_published_date", + return_value=(None, None), + ) as mock_rec: + pipeline.run(_make_question(cutoff)) + cdc_calls = [ + c for c in mock_rec.call_args_list + if c.args and "cdc.gov" in c.args[0] + ] + assert cdc_calls + for call in cdc_calls: + assert call.kwargs.get("use_wayback") is True + + +def test_cutoff_applied_persisted_on_results(): + cutoff = datetime(2024, 6, 1, tzinfo=timezone.utc) + backend = _FakeBackend( + [ + RawSearchResult( + url="https://news.example.com/x", + title="X", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + results = pipeline.run(_make_question(cutoff)) + assert results + for r in results: + assert r.cutoff_applied == cutoff diff --git a/bioscancast/tests/test_date_recovery.py b/bioscancast/tests/test_date_recovery.py new file mode 100644 index 0000000..54c142a --- /dev/null +++ b/bioscancast/tests/test_date_recovery.py @@ -0,0 +1,95 @@ +from datetime import datetime, timezone +from unittest.mock import patch + +from bioscancast.stages.search_stage.date_recovery import ( + date_from_last_modified, + date_from_url_slug, + recover_published_date, +) + + +class TestDateFromUrlSlug: + def test_year_month_day_path(self): + assert date_from_url_slug( + "https://example.com/2024/03/15/some-article" + ) == datetime(2024, 3, 15, tzinfo=timezone.utc) + + def test_year_month_only(self): + assert date_from_url_slug( + "https://example.com/news/2023/06/topic" + ) == datetime(2023, 6, 1, tzinfo=timezone.utc) + + def test_iso_dashed(self): + assert date_from_url_slug( + "https://example.com/p/2025-01-20/title" + ) == datetime(2025, 1, 20, tzinfo=timezone.utc) + + def test_no_match(self): + assert date_from_url_slug("https://example.com/about/contact") is None + + def test_implausible_year_rejected(self): + # 1872 looks like a year but is too old to be a sensible publication + assert date_from_url_slug("https://example.com/1872/03/15") is None + + +class TestDateFromLastModified: + def test_no_fetcher_returns_none(self): + # Off by default: requires explicit injection + assert date_from_last_modified("https://example.com/a") is None + + def test_rfc7231_format(self): + header = "Wed, 21 Oct 2015 07:28:00 GMT" + result = date_from_last_modified( + "https://example.com/a", head_fetcher=lambda _: header + ) + assert result == datetime(2015, 10, 21, 7, 28, 0, tzinfo=timezone.utc) + + def test_fetcher_returning_none(self): + assert date_from_last_modified( + "https://example.com/a", head_fetcher=lambda _: None + ) is None + + def test_fetcher_raises_returns_none(self): + def boom(_): + raise RuntimeError("network down") + + assert date_from_last_modified("https://example.com/a", head_fetcher=boom) is None + + def test_unparseable_header(self): + assert date_from_last_modified( + "https://example.com/a", head_fetcher=lambda _: "not a date" + ) is None + + +class TestRecoverPublishedDate: + def test_url_slug_wins(self): + dt, source = recover_published_date( + "https://example.com/2024/03/15/x", use_wayback=False + ) + assert source == "url_slug" + assert dt == datetime(2024, 3, 15, tzinfo=timezone.utc) + + def test_wayback_used_when_no_slug(self): + with patch( + "bioscancast.stages.search_stage.date_recovery._wayback_first_seen" + ) as mock_wb: + mock_wb.return_value = datetime(2020, 1, 1, tzinfo=timezone.utc) + dt, source = recover_published_date("https://example.com/about") + assert source == "wayback_first_seen" + assert dt == datetime(2020, 1, 1, tzinfo=timezone.utc) + + def test_all_strategies_fail(self): + with patch( + "bioscancast.stages.search_stage.date_recovery._wayback_first_seen" + ) as mock_wb: + mock_wb.return_value = None + dt, source = recover_published_date("https://example.com/about") + assert dt is None + assert source is None + + def test_wayback_disabled(self): + dt, source = recover_published_date( + "https://example.com/about", use_wayback=False + ) + assert dt is None + assert source is None diff --git a/bioscancast/tests/test_extraction_pipeline.py b/bioscancast/tests/test_extraction_pipeline.py index f7b99d0..4aba2d3 100644 --- a/bioscancast/tests/test_extraction_pipeline.py +++ b/bioscancast/tests/test_extraction_pipeline.py @@ -74,7 +74,7 @@ def _make_fetch_result( def _fake_fetch_factory(mapping: dict[str, FetchResult]): """Return a fetch function that looks up results by URL.""" - def fake_fetch(url, *, config=None): + def fake_fetch(url, *, config=None, as_of_date=None): if url in mapping: return mapping[url] return _make_fetch_result(url, b"", error="not_found") diff --git a/bioscancast/tests/test_historical_topup.py b/bioscancast/tests/test_historical_topup.py new file mode 100644 index 0000000..15fb484 --- /dev/null +++ b/bioscancast/tests/test_historical_topup.py @@ -0,0 +1,253 @@ +"""Tests for the historical-mode year-hint and top-up behavior in +SearchStagePipeline. +""" + +from datetime import datetime, timezone +from typing import List + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.stages.search_stage.backends.base import RawSearchResult +from bioscancast.stages.search_stage.pipeline import SearchStagePipeline + + +class _FakeLLM: + def __init__(self): + self._calls = 0 + + def generate_json(self, prompt: str) -> dict: + self._calls += 1 + if self._calls == 1: + return {"question_type": "outbreak_count"} + return { + "sub_queries": [ + {"text": "H5N1 cases", "axis": "latest_data"}, + {"text": "avian flu trend", "axis": "trend"}, + {"text": "bird flu policy", "axis": "policy"}, + ] + } + + +class _RecordingBackend: + """Records every (query, max_results) it was called with and returns a + canned mapping. Same URL can appear across calls — pipeline dedup + handles that.""" + + def __init__(self, results_by_query: dict[tuple[str, int], List[RawSearchResult]] | None = None): + self.calls: list[tuple[str, int]] = [] + self._results = results_by_query or {} + # Fallback results for any query not explicitly mapped. + self._fallback: List[RawSearchResult] = [] + + def set_fallback(self, results: List[RawSearchResult]) -> None: + self._fallback = results + + def search(self, query, max_results=10, end_date=None, start_date=None): + self.calls.append((query, max_results)) + # Prefer exact match on (query, max_results); else any match on + # query; else fallback. + if (query, max_results) in self._results: + return list(self._results[(query, max_results)]) + for (q, _), res in self._results.items(): + if q == query: + return list(res) + return list(self._fallback) + + +def _question(as_of: datetime | None) -> ForecastQuestion: + return ForecastQuestion( + id="Q-TU", + text="H5N1 outbreak in 2024", + created_at=datetime(2024, 6, 1, tzinfo=timezone.utc), + pathogen="nopathogen", # skip dashboard lookup + as_of_date=as_of, + ) + + +def test_year_hint_appended_in_historical_mode(): + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/a", + title="A", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + pipeline.run(_question(datetime(2024, 6, 1, tzinfo=timezone.utc))) + # Every query the backend saw should end in " 2024" + assert backend.calls, "backend should have been called" + queries = [q for q, _ in backend.calls] + assert all(q.endswith(" 2024") for q in queries), queries + + +def test_year_hint_skipped_in_live_mode(): + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/a", + title="A", + snippet="", + rank=1, + published_date=None, + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_FakeLLM(), backend_name="fake" + ) + pipeline.run(_question(as_of=None)) + queries = [q for q, _ in backend.calls] + assert not any(q.endswith(" 2024") for q in queries), queries + + +def test_year_hint_not_double_appended_if_already_present(): + # If the LLM's sub-query already mentions the year, don't append it again. + class _LLM: + def __init__(self): + self._n = 0 + + def generate_json(self, prompt: str) -> dict: + self._n += 1 + if self._n == 1: + return {"question_type": "outbreak_count"} + return { + "sub_queries": [ + {"text": "H5N1 cases 2024", "axis": "latest_data"}, + ] + } + + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/a", + title="A", + snippet="", + rank=1, + published_date="2024-01-01", + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, llm_client=_LLM(), backend_name="fake" + ) + pipeline.run(_question(datetime(2024, 6, 1, tzinfo=timezone.utc))) + queries = [q for q, _ in backend.calls] + # Should NOT be "H5N1 cases 2024 2024" + assert all(q.count("2024") == 1 for q in queries), queries + + +def test_top_up_fires_when_survivors_below_threshold(): + """First round returns mostly post-cutoff (so few survive); top-up + round with bigger max_results returns extras that include pre-cutoff + items. The backend should be called once per sub-query per round.""" + as_of = datetime(2024, 6, 1, tzinfo=timezone.utc) + + # Build the per-query result sets. + round1 = [ + RawSearchResult( + url=f"https://news.example.com/post-{i}", + title="post", + snippet="", + rank=i, + published_date="2024-09-01", # post-cutoff + ) + for i in range(3) + ] + round2 = round1 + [ + RawSearchResult( + url=f"https://news.example.com/pre-{i}", + title="pre", + snippet="", + rank=i + 10, + published_date="2024-01-01", # pre-cutoff + ) + for i in range(20) + ] + + backend = _RecordingBackend() + # Three sub-queries from _FakeLLM each get year-hinted: + for query in ("H5N1 cases 2024", "avian flu trend 2024", "bird flu policy 2024"): + backend._results[(query, 10)] = round1 + backend._results[(query, 50)] = round2 + + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + min_post_filter_results=10, + top_up_results_per_query=50, + max_top_up_rounds=1, + ) + results = pipeline.run(_question(as_of)) + + # Each sub-query should have been called twice: once with max=10 and once + # with max=50. + max_results_seen = [m for _, m in backend.calls] + assert 10 in max_results_seen + assert 50 in max_results_seen + # After top-up we should have well over the threshold of pre-cutoff results. + assert len(results) >= 10 + + +def test_top_up_skipped_when_survivors_meet_threshold(): + """If the initial round already returns enough survivors, no top-up.""" + as_of = datetime(2024, 6, 1, tzinfo=timezone.utc) + plenty = [ + RawSearchResult( + url=f"https://news.example.com/x-{i}", + title=f"x{i}", + snippet="", + rank=i, + published_date="2024-01-01", + ) + for i in range(20) + ] + backend = _RecordingBackend() + backend.set_fallback(plenty) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + min_post_filter_results=10, + top_up_results_per_query=50, + max_top_up_rounds=1, + ) + pipeline.run(_question(as_of)) + # Only the initial round (max=10) should have fired. + max_results_seen = {m for _, m in backend.calls} + assert max_results_seen == {10} + + +def test_top_up_skipped_in_live_mode(): + """Live mode never tops up, even when result count is low.""" + backend = _RecordingBackend() + backend.set_fallback( + [ + RawSearchResult( + url="https://news.example.com/only", + title="only", + snippet="", + rank=1, + published_date=None, + ) + ] + ) + pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=_FakeLLM(), + backend_name="fake", + min_post_filter_results=10, + top_up_results_per_query=50, + max_top_up_rounds=2, + ) + pipeline.run(_question(as_of=None)) + max_results_seen = {m for _, m in backend.calls} + assert max_results_seen == {10} diff --git a/bioscancast/tests/test_search_filtering_integration.py b/bioscancast/tests/test_search_filtering_integration.py index 5dd7405..28b8d1c 100644 --- a/bioscancast/tests/test_search_filtering_integration.py +++ b/bioscancast/tests/test_search_filtering_integration.py @@ -20,7 +20,9 @@ class RealisticFakeSearchBackend: """Returns results with titles/snippets that overlap with the H5N1 question, simulating what a real search engine would return.""" - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, query: str, max_results: int = 10, end_date=None, start_date=None + ) -> List[RawSearchResult]: return [ RawSearchResult( url="https://www.cdc.gov/bird-flu/situation-summary/", diff --git a/bioscancast/tests/test_search_pipeline.py b/bioscancast/tests/test_search_pipeline.py index bc53723..907d6a4 100644 --- a/bioscancast/tests/test_search_pipeline.py +++ b/bioscancast/tests/test_search_pipeline.py @@ -78,7 +78,9 @@ def _default_results() -> List[RawSearchResult]: ), ] - def search(self, query: str, max_results: int = 10) -> List[RawSearchResult]: + def search( + self, query: str, max_results: int = 10, end_date=None, start_date=None + ) -> List[RawSearchResult]: self.queries_received.append(query) return self._results diff --git a/bioscancast/tests/test_tavily_backend.py b/bioscancast/tests/test_tavily_backend.py new file mode 100644 index 0000000..b80842e --- /dev/null +++ b/bioscancast/tests/test_tavily_backend.py @@ -0,0 +1,83 @@ +"""Unit tests for TavilyBackend's date-window forwarding. + +The Tavily news endpoint silently ignores ``end_date`` unless ``start_date`` +is also passed (verified 2026-05-20, see +``specs/tavily-investigation-findings.md``). The backend's job is to forward +the pair when both are present, drop ``end_date`` alone with a warning, +and call the SDK with no date params otherwise. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend + + +class _FakeTavilyClient: + """Captures the kwargs of every ``search`` call so tests can assert on them.""" + + def __init__(self, *_args, **_kwargs): + self.calls: list[dict[str, Any]] = [] + _FakeTavilyClient.last_instance = self + + def search(self, **kwargs): + self.calls.append(kwargs) + return {"results": []} + + +@pytest.fixture +def fake_tavily(monkeypatch): + """Patch tavily.TavilyClient so no network call is made.""" + import tavily + + monkeypatch.setattr(tavily, "TavilyClient", _FakeTavilyClient) + yield _FakeTavilyClient + + +def test_forwards_start_and_end_date_pair(fake_tavily): + backend = TavilyBackend(api_key="test-key") + backend.search( + "H5N1 cases", max_results=5, + start_date="2024-01-01", end_date="2025-02-17", + ) + call = fake_tavily.last_instance.calls[-1] + assert call["start_date"] == "2024-01-01" + assert call["end_date"] == "2025-02-17" + assert call["topic"] == "news" + assert call["max_results"] == 5 + + +def test_drops_end_date_when_start_date_missing(fake_tavily, caplog): + """Tavily ignores end_date alone — sending it would mislead anyone reading + the request log. The backend logs a warning and omits both.""" + backend = TavilyBackend(api_key="test-key") + with caplog.at_level("WARNING"): + backend.search("Mpox cases", end_date="2025-02-17") + call = fake_tavily.last_instance.calls[-1] + assert "end_date" not in call + assert "start_date" not in call + assert any("end_date" in rec.message and "start_date" in rec.message + for rec in caplog.records), ( + "expected a warning when end_date is passed without start_date" + ) + + +def test_no_date_params_in_live_mode(fake_tavily): + backend = TavilyBackend(api_key="test-key") + backend.search("H5N1 cases") + call = fake_tavily.last_instance.calls[-1] + assert "start_date" not in call + assert "end_date" not in call + assert call["topic"] == "news" + + +def test_start_date_without_end_date_is_also_dropped(fake_tavily): + """The pair must be complete; lone start_date is also ignored upstream.""" + backend = TavilyBackend(api_key="test-key") + backend.search("H5N1 cases", start_date="2024-01-01") + call = fake_tavily.last_instance.calls[-1] + assert "start_date" not in call + assert "end_date" not in call diff --git a/bioscancast/tests/test_wayback_fetch.py b/bioscancast/tests/test_wayback_fetch.py new file mode 100644 index 0000000..32b220a --- /dev/null +++ b/bioscancast/tests/test_wayback_fetch.py @@ -0,0 +1,115 @@ +"""Offline tests for the Wayback rewrite in the extraction fetcher. + +The patching reaches into ``bioscancast.extraction.fetcher.closest_snapshot_before`` +(the symbol imported at module load) and ``curl_requests.get``, never touching +the network. There is also a ``@pytest.mark.live`` smoke test for hitting +Wayback for real. +""" + +from datetime import datetime, timezone +from unittest.mock import patch + +import pytest + +from bioscancast.extraction.fetcher import fetch + + +class _FakeResponse: + def __init__(self, *, body: bytes, url: str, status: int = 200): + self.status_code = status + self.headers = {"content-type": "text/html"} + self.url = url + self._body = body + + def iter_content(self): + yield self._body + + def close(self): + pass + + +def _patch_curl(body: bytes, url: str = "https://example.com/page"): + return patch( + "bioscancast.extraction.fetcher.curl_requests.get", + return_value=_FakeResponse(body=body, url=url), + ) + + +def _patch_snapshot(value): + return patch( + "bioscancast.extraction.fetcher.closest_snapshot_before", + return_value=value, + ) + + +class TestWaybackRewrite: + def test_live_mode_no_wayback_call(self): + with _patch_curl(b"live") as mock_get, _patch_snapshot(None) as mock_snap: + result = fetch("https://example.com/page", as_of_date=None) + assert result.fetch_strategy == "live" + assert result.snapshot_timestamp is None + mock_snap.assert_not_called() + mock_get.assert_called_once() + + def test_wayback_success(self): + snap_dt = datetime(2024, 3, 1, 12, 0, 0, tzinfo=timezone.utc) + snap_url = "https://web.archive.org/web/20240301120000id_/https://example.com/page" + with _patch_snapshot((snap_dt, snap_url)), _patch_curl(b"snapshot"): + result = fetch( + "https://example.com/page", + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + assert result.fetch_strategy == "wayback" + assert result.snapshot_timestamp == snap_dt + assert result.url == "https://example.com/page" # original, not archive.org + assert result.content_bytes == b"snapshot" + + def test_no_snapshot_falls_back_to_live(self): + with _patch_snapshot(None), _patch_curl(b"live"): + result = fetch( + "https://example.com/page", + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + assert result.fetch_strategy == "wayback_fallback_to_live" + assert result.snapshot_timestamp is None + assert result.url == "https://example.com/page" + + def test_wayback_fetch_error_falls_back_to_live(self): + snap_dt = datetime(2024, 3, 1, tzinfo=timezone.utc) + snap_url = "https://web.archive.org/web/20240301120000id_/https://example.com/page" + # First call (to Wayback) errors; second call (live) succeeds. + responses = [ + ConnectionError("wayback down"), + _FakeResponse(body=b"live", url="https://example.com/page"), + ] + + def side_effect(*args, **kwargs): + r = responses.pop(0) + if isinstance(r, Exception): + raise r + return r + + with _patch_snapshot((snap_dt, snap_url)), patch( + "bioscancast.extraction.fetcher.curl_requests.get", side_effect=side_effect + ): + result = fetch( + "https://example.com/page", + as_of_date=datetime(2024, 6, 1, tzinfo=timezone.utc), + ) + assert result.fetch_strategy == "wayback_fallback_to_live" + assert result.content_bytes == b"live" + + +@pytest.mark.live +def test_live_wayback_lookup(): + """Smoke-test the real Wayback CDX endpoint. Skipped by default.""" + from bioscancast.stages.search_stage.wayback import closest_snapshot_before + + result = closest_snapshot_before( + "https://www.cdc.gov/", + datetime(2023, 1, 1, tzinfo=timezone.utc), + ) + assert result is not None + snap_dt, snap_url = result + assert snap_dt < datetime(2023, 1, 2, tzinfo=timezone.utc) + assert "web.archive.org/web/" in snap_url diff --git a/bioscancast/tests/test_wayback_retry.py b/bioscancast/tests/test_wayback_retry.py new file mode 100644 index 0000000..bb8b8b5 --- /dev/null +++ b/bioscancast/tests/test_wayback_retry.py @@ -0,0 +1,150 @@ +"""Retry/backoff behavior for the Wayback CDX client.""" + +from __future__ import annotations + +import socket +import urllib.error +from io import BytesIO +from unittest.mock import patch + +from bioscancast.stages.search_stage import wayback + + +def _http_error(code: int) -> urllib.error.HTTPError: + return urllib.error.HTTPError( + url="https://web.archive.org/cdx/search/cdx", + code=code, + msg=str(code), + hdrs=None, # type: ignore[arg-type] + fp=None, + ) + + +def _ok_response(payload: bytes): + """A minimal stand-in for the context manager returned by urlopen.""" + + class _CM: + def __enter__(self): + return BytesIO(payload) + + def __exit__(self, *a): + return False + + return _CM() + + +class TestCdxRetry: + def _no_sleep(self): + return patch.object(wayback, "_sleep", lambda _s: None) + + def _short_schedule(self): + # 3 attempts max so tests are predictable; all delays are no-ops. + return patch.object(wayback, "RETRY_BACKOFF_SECONDS", (0, 0, 0)) + + def test_retries_then_succeeds_on_503(self): + # First two calls 503, third returns valid JSON. + seq = [ + _http_error(503), + _http_error(503), + _ok_response(b'[["urlkey","timestamp","original"],["a","20240101120000","b"]]'), + ] + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=seq, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data is not None + assert data == [["a", "20240101120000", "b"]] + + def test_gives_up_after_max_attempts_503(self): + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_http_error(503)] * 3, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data is None + + def test_retries_on_timeout(self): + seq = [ + socket.timeout("read timeout"), + _ok_response(b'[["urlkey","timestamp","original"]]'), + ] + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=seq, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + # Header-only payload → empty rows list + assert data == [] + + def test_non_recoverable_status_does_not_retry(self): + # 404 should fail immediately with no retries + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_http_error(404)], + ) as mock_open: + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data is None + assert mock_open.call_count == 1 + + def test_recoverable_statuses_cover_5xx_and_429(self): + # 429 is rate-limit; should be treated as recoverable. + seq = [ + _http_error(429), + _ok_response(b'[["header"]]'), + ] + with self._short_schedule(), self._no_sleep(), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=seq, + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data == [] + + +class TestCdxThrottle: + """Proactive min-interval pacing in front of every urlopen.""" + + def test_throttle_paces_successive_calls(self): + sleep_calls: list[float] = [] + ok = b'[["urlkey","timestamp","original"],["a","20240101120000","b"]]' + with patch.object(wayback, "_last_call_monotonic", 0.0), patch.object( + wayback, "_min_interval_seconds", lambda: 5.0 + ), patch.object(wayback, "_sleep", lambda s: sleep_calls.append(s)), patch.object( + wayback, "RETRY_BACKOFF_SECONDS", (0,) + ), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_ok_response(ok), _ok_response(ok)], + ): + wayback._cdx_query({"url": "https://example.com/a"}) + wayback._cdx_query({"url": "https://example.com/b"}) + positive_waits = [s for s in sleep_calls if s > 0] + assert len(positive_waits) == 1 + assert 4.0 < positive_waits[0] <= 5.0 + + def test_throttle_fires_before_each_retry(self): + # Throttle paces before every urlopen — including retried ones — so a + # 503 → OK sequence yields two _throttle() calls, the second of which + # sleeps because the first urlopen just bumped _last_call_monotonic. + sleep_calls: list[float] = [] + ok = b'[["urlkey","timestamp","original"]]' + with patch.object(wayback, "_last_call_monotonic", 0.0), patch.object( + wayback, "_min_interval_seconds", lambda: 3.0 + ), patch.object(wayback, "_sleep", lambda s: sleep_calls.append(s)), patch.object( + wayback, "RETRY_BACKOFF_SECONDS", (0, 0, 0) + ), patch( + "bioscancast.stages.search_stage.wayback.urllib.request.urlopen", + side_effect=[_http_error(503), _ok_response(ok)], + ): + data = wayback._cdx_query({"url": "https://example.com/"}) + assert data == [] + positive_waits = [s for s in sleep_calls if s > 0] + assert len(positive_waits) == 1 + assert 2.0 < positive_waits[0] <= 3.0 + + def test_min_interval_env_override(self, monkeypatch): + monkeypatch.setenv("BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS", "1.5") + assert wayback._min_interval_seconds() == 1.5 + monkeypatch.setenv("BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS", "not-a-number") + assert wayback._min_interval_seconds() == wayback._DEFAULT_MIN_INTERVAL_SECONDS + monkeypatch.delenv("BIOSCANCAST_WAYBACK_MIN_INTERVAL_SECONDS", raising=False) + assert wayback._min_interval_seconds() == wayback._DEFAULT_MIN_INTERVAL_SECONDS diff --git a/scripts/analyze_tavily_probe.py b/scripts/analyze_tavily_probe.py new file mode 100644 index 0000000..7d75764 --- /dev/null +++ b/scripts/analyze_tavily_probe.py @@ -0,0 +1,241 @@ +"""Consolidate Tavily probe-results JSON dumps into a hit-rate table. + +Reads every ``specs/probe-results/*.json`` produced by ``probe_tavily_topic.py``, +applies the production cutoff filter + URL-slug date recovery, and prints a +markdown table suitable for pasting into the findings doc. + +Also computes a "hybrid" row per question_id: union of news + general results +under matching knobs, deduped by URL. + +No network calls. Safe to re-run any time. +""" + +from __future__ import annotations + +import json +import os +import sys +from collections import defaultdict +from datetime import date, datetime +from pathlib import Path +from typing import Any, Iterable + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from bioscancast.stages.search_stage.date_recovery import date_from_url_slug + +REPO_ROOT = Path(__file__).resolve().parent.parent +RESULTS_DIR = REPO_ROOT / "specs" / "probe-results" + + +def parse_published_date(dstr: str | None) -> date | None: + if not dstr: + return None + try: + return date.fromisoformat(dstr[:10]) + except ValueError: + pass + try: + from email.utils import parsedate_to_datetime + return parsedate_to_datetime(dstr).date() + except (ValueError, TypeError): + return None + + +def classify_one(result: dict[str, Any], cutoff: date) -> dict[str, Any]: + url = result.get("url", "") + raw = result.get("published_date") + pd = parse_published_date(raw) + slug = date_from_url_slug(url) + slug_d = slug.date() if slug else None + # "effective" date: prefer native published_date, fall back to slug. + effective = pd or slug_d + return { + "url": url, + "title": result.get("title", ""), + "raw_published_date": raw, + "parsed_published_date": pd.isoformat() if pd else None, + "slug_date": slug_d.isoformat() if slug_d else None, + "effective_date": effective.isoformat() if effective else None, + "native_pre_cutoff": pd is not None and pd <= cutoff, + "effective_pre_cutoff": effective is not None and effective <= cutoff, + "native_dated": pd is not None, + "effective_dated": effective is not None, + } + + +def analyze_payload(payload: dict[str, Any]) -> dict[str, Any]: + cutoff = date.fromisoformat(payload["cutoff"]) + results = payload["response"].get("results", []) or [] + classified = [classify_one(r, cutoff) for r in results] + n = len(classified) or 1 + return { + "tag": payload["tag"], + "query": payload["query"], + "cutoff": payload["cutoff"], + "knobs": payload["knobs"], + "n_results": len(classified), + "native_pre_cutoff": sum(1 for c in classified if c["native_pre_cutoff"]), + "native_dated": sum(1 for c in classified if c["native_dated"]), + "effective_pre_cutoff": sum(1 for c in classified if c["effective_pre_cutoff"]), + "effective_dated": sum(1 for c in classified if c["effective_dated"]), + "results": classified, + "fetched_at": payload.get("fetched_at"), + } + + +def knob_summary(knobs: dict[str, Any]) -> str: + """Compact human-readable summary of the non-default knobs.""" + parts = [] + topic = knobs.get("topic", "news") + parts.append(topic) + for k, v in sorted(knobs.items()): + if k in {"topic", "max_results", "include_answer"}: + continue + if k == "include_domains": + parts.append(f"domains={len(v)}") + else: + parts.append(f"{k}={v}") + return " ".join(parts) + + +def load_all() -> list[dict[str, Any]]: + out = [] + if not RESULTS_DIR.exists(): + return out + for path in sorted(RESULTS_DIR.glob("*.json")): + with path.open(encoding="utf-8") as f: + payload = json.load(f) + out.append(analyze_payload(payload)) + return out + + +def emit_table(rows: Iterable[dict[str, Any]]) -> str: + rows = list(rows) + lines = [] + header = "| tag | config | n | native pre/dated | + slug pre/dated |" + sep = "|---|---|---|---|---|" + lines.append(header) + lines.append(sep) + for r in rows: + cfg = knob_summary(r["knobs"]) + native = f"{r['native_pre_cutoff']}/{r['native_dated']}" + eff = f"{r['effective_pre_cutoff']}/{r['effective_dated']}" + lines.append(f"| {r['tag']} | {cfg} | {r['n_results']} | {native} | {eff} |") + return "\n".join(lines) + + +def compute_hybrid(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + """For each tag, find the news-topic and general-topic rows under otherwise + matching knobs and produce a unioned hybrid row.""" + by_tag: dict[str, dict[str, list[dict[str, Any]]]] = defaultdict(lambda: defaultdict(list)) + for r in rows: + topic = r["knobs"].get("topic", "news") + # Match by (tag, non-topic-knobs); store both topic variants. + non_topic = {k: v for k, v in r["knobs"].items() if k != "topic"} + key = json.dumps(non_topic, sort_keys=True) + by_tag[r["tag"]][key].append(r) + + hybrid_rows = [] + for tag, by_knobs in by_tag.items(): + for key, group in by_knobs.items(): + if len({r["knobs"].get("topic") for r in group}) < 2: + continue + news = [r for r in group if r["knobs"].get("topic") == "news"] + general = [r for r in group if r["knobs"].get("topic") == "general"] + if not news or not general: + continue + news_r = news[0] + general_r = general[0] + seen_urls: set[str] = set() + unioned = [] + for src in (news_r, general_r): + for c in src["results"]: + if c["url"] in seen_urls: + continue + seen_urls.add(c["url"]) + unioned.append(c) + native_pre = sum(1 for c in unioned if c["native_pre_cutoff"]) + native_dated = sum(1 for c in unioned if c["native_dated"]) + eff_pre = sum(1 for c in unioned if c["effective_pre_cutoff"]) + eff_dated = sum(1 for c in unioned if c["effective_dated"]) + cfg_knobs = {**json.loads(key), "topic": "hybrid(news+general)"} + hybrid_rows.append({ + "tag": tag, + "query": news_r["query"], + "cutoff": news_r["cutoff"], + "knobs": cfg_knobs, + "n_results": len(unioned), + "native_pre_cutoff": native_pre, + "native_dated": native_dated, + "effective_pre_cutoff": eff_pre, + "effective_dated": eff_dated, + "results": unioned, + }) + return hybrid_rows + + +def print_url_slug_coverage(rows: list[dict[str, Any]]) -> None: + """Audit: for general-mode rows with no native dates, what fraction of URLs + yield a date via the slug regex?""" + print("\n## URL-slug recovery coverage (general-mode, no native date)\n") + print("| tag | knobs | undated_urls | slug_recovered | recovery_rate |") + print("|---|---|---|---|---|") + for r in rows: + if r["knobs"].get("topic") != "general": + continue + undated = [c for c in r["results"] if not c["native_dated"]] + recovered = [c for c in undated if c["slug_date"] is not None] + if not undated: + continue + print( + f"| {r['tag']} | {knob_summary(r['knobs'])} | {len(undated)} | " + f"{len(recovered)} | {len(recovered) / len(undated):.0%} |" + ) + + +def print_undated_url_sample(rows: list[dict[str, Any]], n: int = 30) -> None: + """For Phase E: list a sample of undated, slug-non-matching URLs so we can + eyeball what patterns Tavily-general returns.""" + print("\n## Undated URLs that the slug regex does NOT catch (sample)\n") + seen: set[str] = set() + count = 0 + for r in rows: + if r["knobs"].get("topic") != "general": + continue + for c in r["results"]: + if c["native_dated"] or c["slug_date"] is not None: + continue + if c["url"] in seen: + continue + seen.add(c["url"]) + print(f"- [{r['tag']}] {c['url']}") + count += 1 + if count >= n: + return + + +def main() -> None: + rows = load_all() + if not rows: + print("No probe-results/*.json found. Run probe_tavily_topic.py first.") + return + + print(f"# Tavily probe analysis ({len(rows)} runs)\n") + print("## All runs\n") + print(emit_table(rows)) + + hybrid = compute_hybrid(rows) + if hybrid: + print("\n## Hybrid (news+general union)\n") + print(emit_table(hybrid)) + + print_url_slug_coverage(rows) + print_undated_url_sample(rows) + + # Total call count = number of payloads (one Tavily call per cache entry). + print(f"\n_Total cached Tavily calls: {len(rows)}_") + + +if __name__ == "__main__": + main() diff --git a/scripts/probe_tavily_topic.py b/scripts/probe_tavily_topic.py new file mode 100644 index 0000000..2cd4648 --- /dev/null +++ b/scripts/probe_tavily_topic.py @@ -0,0 +1,245 @@ +"""Probe Tavily configurations across the BioScanCast resolved corpus. + +Originally a single-query script comparing topic="news" vs topic="general" on +q1 (H5N1 US, cutoff Feb 17 2025). Now generalized to iterate the corpus and +explore Tavily knobs (search_depth, include_domains, exact_match, etc.) under +the historical-replay cutoff machinery. + +Investigation context: see ``specs/tavily-historical-coverage.md`` and the +plan at ``~/.claude/plans/i-d-like-you-to-wondrous-whale.md``. + +Each (question x config) result is dumped to ``specs/probe-results/`` as JSON +so the analyzer (``analyze_tavily_probe.py``) can re-compute hit rates and +date-recovery coverage offline without re-paying the Tavily quota. + +Examples: + # All resolved questions, news topic, default settings + python scripts/probe_tavily_topic.py --question-id all --topic news + + # Single question, advanced search_depth + python scripts/probe_tavily_topic.py --question-id q1 --topic news \ + --knobs '{"search_depth": "advanced"}' + + # Synthetic backdated query (override question text + cutoff) + python scripts/probe_tavily_topic.py --synthetic-query "MERS-CoV cases Saudi Arabia 2015" \ + --synthetic-cutoff 2017-01-01 --synthetic-tag mers2015 --topic news + + # Original q1/news+general behavior (legacy) + python scripts/probe_tavily_topic.py --legacy +""" + +from __future__ import annotations + +import argparse +import csv +import hashlib +import json +import os +import sys +from collections import Counter +from datetime import date, datetime, timedelta, timezone +from pathlib import Path +from typing import Any + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +from tavily import TavilyClient + + +REPO_ROOT = Path(__file__).resolve().parent.parent +CORPUS_CSV = REPO_ROOT / "bioscancast" / "stages" / "eval_stage" / "bioscancast_questions.csv" +RESULTS_DIR = REPO_ROOT / "specs" / "probe-results" + + +def excel_serial_to_date(serial: int | str) -> date: + """Excel epoch is 1899-12-30 (Lotus 1-2-3 leap-year bug correction).""" + return (datetime(1899, 12, 30) + timedelta(days=int(serial))).date() + + +def load_resolved_questions() -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + with CORPUS_CSV.open(encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter=";") + for row in reader: + if row.get("question_status") == "resolved": + row["cutoff_date"] = excel_serial_to_date(row["created_date"]) + out.append(row) + return out + + +def get_question(qid: str) -> dict[str, Any]: + for q in load_resolved_questions(): + if q["question_id"] == qid: + return q + raise SystemExit(f"Unknown or unresolved question_id: {qid}") + + +def _bucket(dstr: str | None, cutoff: date) -> str: + if not dstr: + return "no_date" + try: + d = date.fromisoformat(dstr[:10]) + except ValueError: + try: + from email.utils import parsedate_to_datetime + d = parsedate_to_datetime(dstr).date() + except (ValueError, TypeError): + return "unparseable" + if d <= cutoff: + return "pre_cutoff" + return f"post_cutoff_{d.year}" + + +def config_hash(query: str, cutoff: date, knobs: dict[str, Any]) -> str: + payload = json.dumps({"query": query, "cutoff": cutoff.isoformat(), "knobs": knobs}, sort_keys=True) + return hashlib.sha1(payload.encode()).hexdigest()[:10] + + +def cache_path(tag: str, knobs: dict[str, Any]) -> Path: + """Filename: __.json. Tag is question_id or synthetic-tag.""" + knob_summary = "_".join(f"{k}={v}" for k, v in sorted(knobs.items()) if k != "include_domains") + if "include_domains" in knobs: + knob_summary += "_domains=" + str(len(knobs["include_domains"])) + knob_summary = knob_summary.replace("/", "_").replace(":", "_")[:60] or "default" + h = hashlib.sha1(json.dumps(knobs, sort_keys=True).encode()).hexdigest()[:8] + return RESULTS_DIR / f"{tag}__{knob_summary}__{h}.json" + + +def run_probe( + client: TavilyClient, + *, + tag: str, + query: str, + cutoff: date, + knobs: dict[str, Any], + force: bool = False, +) -> dict[str, Any]: + """Run one Tavily call (cached). Returns the cached payload.""" + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + path = cache_path(tag, knobs) + if path.exists() and not force: + with path.open(encoding="utf-8") as f: + return json.load(f) + + kwargs: dict[str, Any] = {"query": query, "include_answer": False, **knobs} + if "max_results" not in kwargs: + kwargs["max_results"] = 20 + resp = client.search(**kwargs) + + payload = { + "tag": tag, + "query": query, + "cutoff": cutoff.isoformat(), + "knobs": knobs, + "fetched_at": datetime.now(timezone.utc).isoformat(), + "response": resp, + } + with path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2) + return payload + + +def summarize(payload: dict[str, Any]) -> None: + cutoff = date.fromisoformat(payload["cutoff"]) + results = payload["response"].get("results", []) or [] + buckets: Counter = Counter() + dated = 0 + for r in results: + d = r.get("published_date") + if d: + dated += 1 + buckets[_bucket(d, cutoff)] += 1 + pre = buckets.get("pre_cutoff", 0) + knob_str = ", ".join(f"{k}={v}" for k, v in sorted(payload["knobs"].items()))[:80] or "(default)" + n = len(results) or 1 + print( + f" {payload['tag']:>10} cutoff={cutoff} {knob_str:<82} " + f"-> pre={pre}/{len(results)} ({pre / n:.0%}) dated={dated}/{len(results)}" + ) + + +def add_year_hint(query: str, cutoff: date) -> str: + """Mirror the pipeline's year-hint suffix so probes match pipeline behavior.""" + y = str(cutoff.year) + if y in query: + return query + return f"{query} {y}" + + +def build_query_from_question(q: dict[str, Any], hint_year: bool = True) -> str: + """Construct a search query from a corpus question. Strip framing words + ("How many ... will be reported ... according to ...") to expose the + topical noun phrase. Keep the topic prefix as a hint.""" + text = q["question_text"] + base = f"{q['topic']} {text}" + return add_year_hint(base, q["cutoff_date"]) if hint_year else base + + +def parse_args(argv: list[str]) -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--question-id", help="Resolved question id (q1, q3, q7, q9) or 'all'.") + p.add_argument("--topic", choices=["news", "general", "finance"], default="news") + p.add_argument("--knobs", default="{}", help="JSON object of extra Tavily kwargs.") + p.add_argument("--synthetic-query", help="Use a free-form query instead of corpus.") + p.add_argument("--synthetic-cutoff", help="YYYY-MM-DD cutoff for synthetic query.") + p.add_argument("--synthetic-tag", help="Short tag for cache filename (synthetic only).") + p.add_argument("--no-year-hint", action="store_true", help="Skip the year-suffix hint.") + p.add_argument("--force", action="store_true", help="Bypass cache and re-call Tavily.") + p.add_argument("--legacy", action="store_true", help="Replicate original q1 news+general behavior.") + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> None: + args = parse_args(argv or sys.argv[1:]) + api_key = os.environ.get("TAVILY_API_KEY") + if not api_key: + sys.exit("TAVILY_API_KEY missing") + client = TavilyClient(api_key=api_key) + + if args.legacy: + q = get_question("q1") + query = build_query_from_question(q, hint_year=False) + cutoff = q["cutoff_date"] + for topic in ("news", "general"): + payload = run_probe( + client, tag="q1_legacy", query=query, cutoff=cutoff, + knobs={"topic": topic, "max_results": 20}, force=args.force, + ) + summarize(payload) + return + + knobs = json.loads(args.knobs) + knobs.setdefault("topic", args.topic) + knobs.setdefault("max_results", 20) + + if args.synthetic_query: + if not args.synthetic_cutoff: + sys.exit("--synthetic-cutoff required with --synthetic-query") + tag = args.synthetic_tag or "synth" + cutoff = date.fromisoformat(args.synthetic_cutoff) + query = args.synthetic_query if args.no_year_hint else add_year_hint(args.synthetic_query, cutoff) + payload = run_probe(client, tag=tag, query=query, cutoff=cutoff, knobs=knobs, force=args.force) + summarize(payload) + return + + if not args.question_id: + sys.exit("provide --question-id, --synthetic-query, or --legacy") + + qids = ["q1", "q3", "q7", "q9"] if args.question_id == "all" else [args.question_id] + for qid in qids: + q = get_question(qid) + query = build_query_from_question(q, hint_year=not args.no_year_hint) + payload = run_probe( + client, tag=qid, query=query, cutoff=q["cutoff_date"], knobs=knobs, force=args.force, + ) + summarize(payload) + + +if __name__ == "__main__": + main() diff --git a/scripts/test_historical_replay.py b/scripts/test_historical_replay.py new file mode 100644 index 0000000..f095bc3 --- /dev/null +++ b/scripts/test_historical_replay.py @@ -0,0 +1,198 @@ +"""Manual smoke test for historical-replay mode. + +Runs the search stage against `q1` (resolved H5N1 US, Feb 28 2025 deadline) +with as_of_date = question.created_date. Prints a digest of what the cutoff +machinery did. Does NOT push through filtering/extraction (issue #13 would +make that uninformative without an LLM and a relaxed threshold). + +What this validates on the feat/as-of-date-replay branch: + - Tavily backend receives end_date matching the cutoff + - All returned SearchResult.published_date <= as_of_date + - Dashboard URLs are Wayback-rewritten (or suppressed if no snapshot) + - SearchResult.cutoff_applied is populated + - The date-recovery chain fires for undated results + +What this also gathers (for issue #5 — Tavily date reliability): + - share of Tavily results that came with a published_date + - share that needed the recovery chain (and which strategy won) + - share that were dropped for no-date-available + +Run: + python scripts/test_historical_replay.py + +Requires TAVILY_API_KEY and OPENAI_API_KEY in environment (or .env). +""" + +from __future__ import annotations + +import logging +import os +import sys +from collections import Counter +from datetime import datetime, timezone + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.llm.client import OpenAIClient +from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend +from bioscancast.stages.search_stage.pipeline import SearchStagePipeline + + +# q1 from bioscancast_questions.csv. created_date is Excel serial 45705 = +# 2025-02-17. Resolution deadline is Feb 28, 2025. Cutoff for the human +# forecaster is the creation date. +Q1_TEXT = ( + "How many confirmed human cases of H5N1 will be reported in the US " + "by February 28, 2025, according to the US dashboard?" +) +Q1_CREATED = datetime(2025, 2, 17, tzinfo=timezone.utc) + + +def _hr(title: str) -> None: + bar = "=" * 72 + print(f"\n{bar}\n{title}\n{bar}") + + +def main() -> None: + # Surface the cutoff-filter and dashboard suppression INFO logs. + logging.basicConfig( + level=logging.INFO, + format="%(levelname)s %(name)s: %(message)s", + ) + # Hide curl_cffi/openai noise but keep our modules chatty. + for noisy in ("urllib3", "httpx", "openai", "tavily"): + logging.getLogger(noisy).setLevel(logging.WARNING) + + question = ForecastQuestion( + id="q1", + text=Q1_TEXT, + created_at=Q1_CREATED, + pathogen="h5n1", + region="United States", + target_date=datetime(2025, 2, 28, tzinfo=timezone.utc), + as_of_date=Q1_CREATED, # the cutoff + ) + + _hr("CONFIGURATION") + print(f"Question text : {question.text}") + print(f"Pathogen : {question.pathogen}") + print(f"Created at : {question.created_at.isoformat()}") + print(f"Target date : {question.target_date.isoformat()}") + print(f"AS-OF (cutoff) : {question.as_of_date.isoformat()}") + print(f"Historical mode : YES (as_of_date is set)") + + llm = OpenAIClient() + # Wrap the Tavily backend so we can observe whether end_date was forwarded. + base_backend = TavilyBackend() + end_dates_seen: list = [] + _orig_search = base_backend.search + + def wrapped_search(query: str, max_results: int = 10, end_date=None): + end_dates_seen.append(end_date) + return _orig_search(query, max_results=max_results, end_date=end_date) + + base_backend.search = wrapped_search # type: ignore[assignment] + + # NB: deliberately running without SearchCache so we hit Tavily fresh and + # the test isn't influenced by entries from a previous (different-cutoff) + # run. The cache key incorporates the cutoff so this is just paranoia. + pipeline = SearchStagePipeline( + search_backend=base_backend, + llm_client=llm, + cache=None, + backend_name="tavily", + # Leave historical_roleplay off — that's a separate opt-in. + ) + + _hr("RUNNING SEARCH STAGE") + results = pipeline.run(question) + + _hr("BACKEND OBSERVATIONS") + print(f"Sub-queries issued to Tavily: {len(end_dates_seen)}") + distinct_end_dates = set(end_dates_seen) + print(f"end_date values forwarded : {distinct_end_dates}") + if distinct_end_dates == {question.as_of_date.strftime('%Y-%m-%d')}: + print(">> end_date correctly forwarded on every call.") + else: + print(">> WARNING: end_date forwarding looks wrong.") + + _hr("RESULT SUMMARY") + print(f"Total results returned: {len(results)}") + + dashboards = [r for r in results if r.engine == "dashboard"] + organic = [r for r in results if r.engine != "dashboard"] + print(f" Dashboards : {len(dashboards)}") + print(f" Organic : {len(organic)}") + + # Cutoff sanity check + leaks = [r for r in results if r.published_date and r.published_date > question.as_of_date] + print(f" Post-cutoff leaks: {len(leaks)}") + if leaks: + for r in leaks: + print(f" LEAK: {r.url} pub={r.published_date.isoformat()}") + else: + print(" >> No post-cutoff results in the final list.") + + # cutoff_applied audit + bad_cutoff = [r for r in results if r.cutoff_applied != question.as_of_date] + print(f" cutoff_applied mismatches: {len(bad_cutoff)}") + + _hr("DASHBOARDS (Wayback rewrite or suppression)") + if not dashboards: + print("No dashboards present — Wayback either had no snapshot or " + "they were suppressed. Check the INFO log lines above.") + for r in dashboards: + in_wayback = "web.archive.org/web/" in r.url + snap_date = r.published_date.isoformat() if r.published_date else "n/a" + print(f" [{('WAYBACK' if in_wayback else 'LIVE!')}] {r.url}") + print(f" snapshot_date={snap_date} source={r.published_date_source}") + + _hr("DATA FOR ISSUE #5 (Tavily published_date reliability)") + source_counter: Counter = Counter(r.published_date_source for r in organic) + print("Per-result published_date_source distribution (organic only):") + for src, n in source_counter.most_common(): + label = src if src is not None else "" + print(f" {label:25s} {n}") + n_backend = source_counter.get("backend", 0) + n_recovered = sum( + n for src, n in source_counter.items() + if src in {"url_slug", "last_modified", "wayback_first_seen"} + ) + n_unsourced = source_counter.get(None, 0) + if organic: + print( + f"\nTavily-supplied date rate: {n_backend}/{len(organic)} = " + f"{n_backend / len(organic):.0%}" + ) + print( + f"Recovery-chain saves : {n_recovered}/{len(organic)} = " + f"{n_recovered / len(organic):.0%}" + ) + print( + f"Unsourced (kept anyway) : {n_unsourced}/{len(organic)} = " + f"{n_unsourced / len(organic):.0%} " + "[expected ~0 in historical mode after the filter]" + ) + + _hr("TOP 15 RESULTS (sorted by search_stage_score)") + for i, r in enumerate(results[:15], 1): + pub = r.published_date.date().isoformat() if r.published_date else "—" + src = r.published_date_source or "—" + print( + f"{i:2d}. score={r.search_stage_score:.3f} " + f"tier={r.source_tier:<13s} pub={pub:<10s} src={src:<20s} " + f"{r.domain}" + ) + print(f" {r.title[:90]}") + print(f" {r.url}") + + +if __name__ == "__main__": + main()