From c1bd01c9e8d87fbf2bafc99138f2ee1d224e55cd Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 13 May 2026 12:20:43 +0000 Subject: [PATCH] Add edge-predicate regression check framework Introduces test_harness/regression_checks/, a pluggable framework for TRAPI-response regression checks that run alongside the existing acceptance/pathfinder pass-fail analysis without overriding genuine failures. The first check (EdgePredicateMatchCheck) verifies that returned knowledge-graph predicates are TRAPI-compatible with the query graph (the queried predicate or a biolink descendant of it), catching cases where the right output CURIE is reached via an edge whose predicate does not satisfy the query. Mismatches surface as a new AgentStatus.REGRESSION so reports can distinguish "wrong answer" from "wrong-shaped answer". A bmt (Biolink Model Toolkit) initialization failure degrades the check to SKIPPED rather than crashing the run. https://claude.ai/code/session_012pDrWMNSK3p4cjJ2yp1FjF --- requirements.txt | 1 + test_harness/regression_checks/__init__.py | 21 +++ test_harness/regression_checks/base.py | 56 +++++++ .../regression_checks/edge_predicate.py | 122 ++++++++++++++ test_harness/result_collector.py | 1 + test_harness/run.py | 30 ++++ test_harness/utils.py | 11 +- tests/test_regression_checks.py | 153 ++++++++++++++++++ 8 files changed, 392 insertions(+), 3 deletions(-) create mode 100644 test_harness/regression_checks/__init__.py create mode 100644 test_harness/regression_checks/base.py create mode 100644 test_harness/regression_checks/edge_predicate.py create mode 100644 tests/test_regression_checks.py diff --git a/requirements.txt b/requirements.txt index f8dca51..2e529fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ slack_sdk==3.27.2 tqdm==4.66.4 translator-testing-model==0.5.0 reasoner-validator==4.2.5 +bmt>=1.4 diff --git a/test_harness/regression_checks/__init__.py b/test_harness/regression_checks/__init__.py new file mode 100644 index 0000000..38c81e8 --- /dev/null +++
b/test_harness/regression_checks/__init__.py @@ -0,0 +1,21 @@ +"""Regression checks run alongside acceptance/pathfinder pass-fail analysis.""" + +from test_harness.regression_checks.base import ( + CHECKS, + RegressionCheck, + RegressionCheckResult, + RegressionStatus, + run_all, +) +from test_harness.regression_checks.edge_predicate import EdgePredicateMatchCheck + +CHECKS.append(EdgePredicateMatchCheck()) + +__all__ = [ + "CHECKS", + "RegressionCheck", + "RegressionCheckResult", + "RegressionStatus", + "EdgePredicateMatchCheck", + "run_all", +] diff --git a/test_harness/regression_checks/base.py b/test_harness/regression_checks/base.py new file mode 100644 index 0000000..1803280 --- /dev/null +++ b/test_harness/regression_checks/base.py @@ -0,0 +1,56 @@ +"""Pluggable regression checks for TRAPI responses.""" + +from dataclasses import dataclass, field +from enum import Enum +import logging +from typing import Any, Dict, List, Optional, Protocol + + +class RegressionStatus(str, Enum): + PASSED = "PASSED" + FAILED = "FAILED" + SKIPPED = "SKIPPED" + + +@dataclass +class RegressionCheckResult: + name: str + status: RegressionStatus + message: Optional[str] = None + details: Optional[Dict[str, Any]] = field(default=None) + + +class RegressionCheck(Protocol): + name: str + + def run( + self, message: Dict[str, Any], query_graph: Dict[str, Any] + ) -> RegressionCheckResult: ... 
+ + +CHECKS: List[RegressionCheck] = [] + + +def run_all( + message: Dict[str, Any], + query_graph: Dict[str, Any], + logger: Optional[logging.Logger] = None, +) -> List[RegressionCheckResult]: + """Run every registered regression check, isolating failures per-check.""" + results: List[RegressionCheckResult] = [] + for check in CHECKS: + try: + results.append(check.run(message, query_graph)) + except Exception as e: + if logger is not None: + logger.warning( + f"Regression check {getattr(check, 'name', type(check).__name__)} crashed: {e}" + ) + results.append( + RegressionCheckResult( + name=getattr(check, "name", type(check).__name__), + status=RegressionStatus.SKIPPED, + message=f"Check crashed: {type(e).__name__}: {e}", + ) + ) + return results diff --git a/test_harness/regression_checks/edge_predicate.py b/test_harness/regression_checks/edge_predicate.py new file mode 100644 index 0000000..b9b04dd --- /dev/null +++ b/test_harness/regression_checks/edge_predicate.py @@ -0,0 +1,122 @@ +"""Edge predicate regression check. + +Verifies that every predicate returned in the knowledge graph (and bound to a +query-graph edge) is the queried predicate itself or a biolink descendant of +it. A returned predicate that is an ancestor of the queried predicate (less +specific) or otherwise unrelated counts as a regression. 
+""" + +from typing import Any, Dict, List, Optional, Set + +from test_harness.regression_checks.base import ( + RegressionCheckResult, + RegressionStatus, +) + + +class EdgePredicateMatchCheck: + name = "edge_predicate_match" + + def __init__(self) -> None: + self._toolkit = None + self._toolkit_init_error: Optional[str] = None + self._descendants_cache: Dict[str, Set[str]] = {} + + def _get_toolkit(self): + if self._toolkit is not None or self._toolkit_init_error is not None: + return self._toolkit + try: + import bmt # imported lazily; bmt.Toolkit() is slow + self._toolkit = bmt.Toolkit() + except Exception as e: + self._toolkit_init_error = f"{type(e).__name__}: {e}" + return self._toolkit + + def _allowed_predicates(self, predicate: str) -> Set[str]: + cached = self._descendants_cache.get(predicate) + if cached is not None: + return cached + toolkit = self._get_toolkit() + if toolkit is None: + return set() + descendants = toolkit.get_descendants( + predicate, reflexive=False, formatted=True + ) or [] + allowed = {predicate, *descendants} + self._descendants_cache[predicate] = allowed + return allowed + + def run( + self, message: Dict[str, Any], query_graph: Dict[str, Any] + ) -> RegressionCheckResult: + qg_edges = (query_graph or {}).get("edges") or {} + qg_edges_with_predicates = { + edge_id: edge + for edge_id, edge in qg_edges.items() + if edge.get("predicates") + } + if not qg_edges_with_predicates: + return RegressionCheckResult( + name=self.name, + status=RegressionStatus.SKIPPED, + message="No query predicates to check.", + ) + + if self._get_toolkit() is None: + return RegressionCheckResult( + name=self.name, + status=RegressionStatus.SKIPPED, + message=f"biolink toolkit unavailable: {self._toolkit_init_error}", + ) + + allowed_by_edge: Dict[str, Set[str]] = {} + for edge_id, edge in qg_edges_with_predicates.items(): + allowed: Set[str] = set() + for predicate in edge["predicates"]: + allowed |= self._allowed_predicates(predicate) + 
allowed_by_edge[edge_id] = allowed + + kg_edges = (message.get("knowledge_graph") or {}).get("edges") or {} + results = message.get("results") or [] + mismatches: List[Dict[str, Any]] = [] + + for result_idx, result in enumerate(results): + for analysis_idx, analysis in enumerate(result.get("analyses") or []): + for qg_edge_id, bindings in (analysis.get("edge_bindings") or {}).items(): + allowed = allowed_by_edge.get(qg_edge_id) + if allowed is None: + continue + for binding in bindings or []: + kg_edge_id = binding.get("id") + kg_edge = kg_edges.get(kg_edge_id) + if kg_edge is None: + mismatches.append({ + "result_index": result_idx, + "analysis_index": analysis_idx, + "qg_edge_id": qg_edge_id, + "kg_edge_id": kg_edge_id, + "reason": "kg_edge_missing", + }) + continue + predicate = kg_edge.get("predicate") + if predicate not in allowed: + mismatches.append({ + "result_index": result_idx, + "analysis_index": analysis_idx, + "qg_edge_id": qg_edge_id, + "kg_edge_id": kg_edge_id, + "predicate": predicate, + "expected_predicates": sorted(allowed), + }) + + if mismatches: + return RegressionCheckResult( + name=self.name, + status=RegressionStatus.FAILED, + message=f"{len(mismatches)} edge(s) returned predicates not compatible with the query graph.", + details={"mismatches": mismatches, "count": len(mismatches)}, + ) + return RegressionCheckResult( + name=self.name, + status=RegressionStatus.PASSED, + ) diff --git a/test_harness/result_collector.py b/test_harness/result_collector.py index 0392e58..dce5765 100644 --- a/test_harness/result_collector.py +++ b/test_harness/result_collector.py @@ -290,6 +290,7 @@ def dump_result_summary(self): > Acceptance Test Results: > Passed: {self.acceptance_report['PASSED']}, > Failed: {self.acceptance_report['FAILED']}, +> Regression: {self.acceptance_report['REGRESSION']}, > Skipped: {self.acceptance_report['SKIPPED']} > No Results: {self.acceptance_report['NO_RESULTS']} > Errors: {self.acceptance_report['ERROR']} diff --git 
a/test_harness/run.py b/test_harness/run.py index 6bb3570..96520ca 100644 --- a/test_harness/run.py +++ b/test_harness/run.py @@ -20,6 +20,10 @@ from test_harness.acceptance_test_runner import run_acceptance_pass_fail_analysis from test_harness.pathfinder_test_runner import pathfinder_pass_fail_analysis from test_harness.performance_test_runner import run_performance_test +from test_harness.regression_checks import ( + RegressionStatus, + run_all as run_regression_checks, +) from test_harness.reporter import Reporter from test_harness.result_collector import ResultCollector from test_harness.runner.generate_query import generate_query @@ -171,6 +175,32 @@ def run_tests( agent_report.status = AgentStatus.FAILED agent_report.message = "Test Error" + try: + check_results = run_regression_checks( + response["response"]["message"], + test_query["query"]["message"].get("query_graph") or {}, + logger, + ) + agent_report.regression_checks.extend(check_results) + failed_checks = [ + r for r in check_results + if r.status == RegressionStatus.FAILED + ] + if failed_checks and agent_report.status == AgentStatus.PASSED: + agent_report.status = AgentStatus.REGRESSION + summary = "; ".join( + f"{r.name}: {r.message or 'failed'}" + for r in failed_checks + ) + agent_report.message = ( + f"{agent_report.message + ' | ' if agent_report.message else ''}" + f"Regression: {summary}" + ) + except Exception as e: + logger.warning( + f"Regression check infrastructure failed on {agent}: {e}" + ) + # grab only ars result if it exists, otherwise default to failed if "ars" not in report.result: status = AgentStatus.SKIPPED diff --git a/test_harness/utils.py b/test_harness/utils.py index 7dfa882..f49ff39 100644 --- a/test_harness/utils.py +++ b/test_harness/utils.py @@ -1,9 +1,12 @@ """General utilities for the Test Harness.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum import logging -from typing import Dict, List, Optional, Tuple, Union 
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +if TYPE_CHECKING: + from test_harness.regression_checks.base import RegressionCheckResult import httpx from translator_testing_model.datamodel.pydanticmodel import ( @@ -27,6 +30,7 @@ class AgentStatus(str, Enum): NO_RESULTS = "NO_RESULTS" SKIPPED = "SKIPPED" ERROR = "ERROR" + REGRESSION = "REGRESSION" @dataclass @@ -36,13 +40,14 @@ class AgentReport: status: AgentStatus message: Optional[str] actual_output: Optional[dict[str, Optional[int]]] + regression_checks: List["RegressionCheckResult"] = field(default_factory=list) @dataclass class PathfinderReport(AgentReport): """Dictionary for single Pathfinder agent report.""" - expected_nodes_found: str + expected_nodes_found: str = "" @dataclass diff --git a/tests/test_regression_checks.py b/tests/test_regression_checks.py new file mode 100644 index 0000000..fbd76fd --- /dev/null +++ b/tests/test_regression_checks.py @@ -0,0 +1,153 @@ +"""Tests for the regression-check framework and the edge-predicate check.""" + +import copy + +import pytest + +from test_harness.regression_checks import ( + EdgePredicateMatchCheck, + RegressionStatus, + run_all, +) +from test_harness.regression_checks.base import ( + CHECKS, + RegressionCheckResult, +) +from tests.helpers.mock_responses import kp_response + + +def _query_graph(): + return copy.deepcopy(kp_response["message"]["query_graph"]) + + +def _message_with_kg_predicate(predicate: str): + """Return a kp_response copy whose single KG edge has the given predicate.""" + message = copy.deepcopy(kp_response["message"]) + message["knowledge_graph"]["edges"]["n0n1"]["predicate"] = predicate + return message + + +def test_exact_predicate_match_passes(): + check = EdgePredicateMatchCheck() + result = check.run(_message_with_kg_predicate("biolink:treats"), _query_graph()) + assert result.status == RegressionStatus.PASSED + + +def test_descendant_predicate_passes(): + # biolink:ameliorates_condition is a descendant of 
biolink:treats + check = EdgePredicateMatchCheck() + result = check.run( + _message_with_kg_predicate("biolink:ameliorates_condition"), + _query_graph(), + ) + assert result.status == RegressionStatus.PASSED + + +def test_ancestor_predicate_fails(): + # biolink:affects is an ancestor of biolink:treats — too general + check = EdgePredicateMatchCheck() + result = check.run( + _message_with_kg_predicate("biolink:affects"), + _query_graph(), + ) + assert result.status == RegressionStatus.FAILED + assert result.details["count"] == 1 + assert result.details["mismatches"][0]["predicate"] == "biolink:affects" + + +def test_unrelated_predicate_fails(): + check = EdgePredicateMatchCheck() + result = check.run( + _message_with_kg_predicate("biolink:related_to"), + _query_graph(), + ) + assert result.status == RegressionStatus.FAILED + + +def test_pathfinder_query_graph_skipped(): + # Pathfinder-shaped query graphs have no edges with predicates. + check = EdgePredicateMatchCheck() + qg = {"nodes": {"n0": {}, "n1": {}}, "edges": {}} + result = check.run(_message_with_kg_predicate("biolink:treats"), qg) + assert result.status == RegressionStatus.SKIPPED + + +def test_query_graph_edge_without_predicates_skipped(): + check = EdgePredicateMatchCheck() + qg = _query_graph() + qg["edges"]["n0n1"]["predicates"] = [] + result = check.run(_message_with_kg_predicate("biolink:treats"), qg) + assert result.status == RegressionStatus.SKIPPED + + +def test_dangling_edge_binding_fails(): + check = EdgePredicateMatchCheck() + message = copy.deepcopy(kp_response["message"]) + # remove the kg edge that the binding references + del message["knowledge_graph"]["edges"]["n0n1"] + result = check.run(message, _query_graph()) + assert result.status == RegressionStatus.FAILED + assert result.details["mismatches"][0]["reason"] == "kg_edge_missing" + + +def test_no_results_passes_vacuously(): + check = EdgePredicateMatchCheck() + message = copy.deepcopy(kp_response["message"]) + message["results"] = [] + 
result = check.run(message, _query_graph()) + assert result.status == RegressionStatus.PASSED + + +def test_toolkit_init_failure_skipped(): + """If bmt cannot initialize, the check skips rather than crashing the run.""" + check = EdgePredicateMatchCheck() + check._toolkit = None + check._toolkit_init_error = "RuntimeError: simulated bmt failure" + # Force _get_toolkit to short-circuit by leaving _toolkit None and + # _toolkit_init_error populated (matches the post-failure state). + original = check._get_toolkit + check._get_toolkit = lambda: None + try: + result = check.run( + _message_with_kg_predicate("biolink:treats"), _query_graph() + ) + finally: + check._get_toolkit = original + assert result.status == RegressionStatus.SKIPPED + assert "simulated bmt failure" in (result.message or "") + + +def test_run_all_isolates_crashing_checks(): + """A buggy check should not crash run_all.""" + + class CrashingCheck: + name = "crasher" + + def run(self, message, query_graph): + raise ValueError("kaboom") + + CHECKS.append(CrashingCheck()) + try: + results = run_all( + _message_with_kg_predicate("biolink:treats"), + _query_graph(), + ) + finally: + CHECKS.pop() + + crasher_results = [r for r in results if r.name == "crasher"] + assert len(crasher_results) == 1 + assert crasher_results[0].status == RegressionStatus.SKIPPED + assert "kaboom" in (crasher_results[0].message or "") + # the real edge_predicate check still ran + assert any(r.name == "edge_predicate_match" for r in results) + + +def test_run_all_returns_one_result_per_check(): + results = run_all( + _message_with_kg_predicate("biolink:treats"), + _query_graph(), + ) + assert len(results) == len(CHECKS) + for r in results: + assert isinstance(r, RegressionCheckResult)