diff --git a/src/core/errors.py b/src/core/errors.py index 3f53364a..a6191915 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -385,3 +385,26 @@ class InternalError(ProblemDetailError): uri = "https://openml.org/problems/internal-error" title = "Internal Server Error" _default_status_code = HTTPStatus.INTERNAL_SERVER_ERROR + + +# ============================================================================= +# Run Errors +# ============================================================================= + + +class RunNotFoundError(ProblemDetailError): + """Raised when a run cannot be found.""" + + uri = "https://openml.org/problems/run-not-found" + title = "Run Not Found" + _default_status_code = HTTPStatus.PRECONDITION_FAILED + _default_code = 571 + + +class RunTraceNotFoundError(ProblemDetailError): + """Raised when trace data for a run cannot be found.""" + + uri = "https://openml.org/problems/run-trace-not-found" + title = "Run Trace Not Found" + _default_status_code = HTTPStatus.PRECONDITION_FAILED + _default_code = 572 diff --git a/src/database/runs.py b/src/database/runs.py new file mode 100644 index 00000000..acf7a532 --- /dev/null +++ b/src/database/runs.py @@ -0,0 +1,40 @@ +"""Database queries for run-related data.""" + +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection + + +async def exist(id_: int, expdb: AsyncConnection) -> bool: + """Check if a run exists by ID.""" + row = await expdb.execute( + text( + """ + SELECT 1 + FROM `run` + WHERE `rid` = :run_id + """, + ), + parameters={"run_id": id_}, + ) + return bool(row.one_or_none()) + + +async def get_trace(run_id: int, expdb: AsyncConnection) -> Sequence[Row]: + """Get trace rows for a run from the trace table.""" + rows = await expdb.execute( + text( + """ + SELECT `repeat`, `fold`, `iteration`, `setup_string`, `evaluation`, `selected` + FROM `trace` + WHERE `run_id` = :run_id + """, + ), + parameters={"run_id": run_id}, + ) + return cast( + "Sequence[Row]", + rows.all(), + ) diff --git a/src/main.py b/src/main.py index 76a52ad3..8ffecd01 100644 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.runs import router as run_router from routers.openml.setups import router as setup_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router @@ -70,6 +71,7 @@ def create_api() -> FastAPI: app.include_router(flows_router) app.include_router(study_router) app.include_router(setup_router) + app.include_router(run_router) return app diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py new file mode 100644 index 00000000..37a7cecf --- /dev/null +++ b/src/routers/openml/runs.py @@ -0,0 +1,44 @@ +"""Endpoints for run-related data.""" + +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncConnection + +import database.runs +from core.errors import RunNotFoundError, RunTraceNotFoundError +from routers.dependencies import expdb_connection +from schemas.runs import RunTrace, TraceIteration + +router = APIRouter(prefix="/run", tags=["run"]) + + +@router.get("/trace/{run_id}") +async def get_run_trace( + run_id: int, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> RunTrace: + """Get trace data for a run by run ID.""" + if not await database.runs.exist(run_id, expdb): + msg = f"Run {run_id} not found." + raise RunNotFoundError(msg) + + trace_rows = await database.runs.get_trace(run_id, expdb) + if not trace_rows: + msg = f"No trace found for run {run_id}." + raise RunTraceNotFoundError(msg) + + return RunTrace( + run_id=run_id, + trace=[ + TraceIteration( + repeat=row.repeat, + fold=row.fold, + iteration=row.iteration, + setup_string=row.setup_string, + evaluation=row.evaluation, + selected=row.selected, + ) + for row in trace_rows + ], + ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py new file mode 100644 index 00000000..857f4921 --- /dev/null +++ b/src/schemas/runs.py @@ -0,0 +1,21 @@ +"""Pydantic schemas for run-related endpoints.""" + +from pydantic import BaseModel + + +class TraceIteration(BaseModel): + """A single trace iteration for a run.""" + + repeat: int + fold: int + iteration: int + setup_string: str | None + evaluation: float | None + selected: str + + +class RunTrace(BaseModel): + """Trace data for a run.""" + + run_id: int + trace: list[TraceIteration] diff --git a/tests/routers/openml/migration/runs_migration_test.py b/tests/routers/openml/migration/runs_migration_test.py new file mode 100644 index 00000000..60555959 --- /dev/null +++ b/tests/routers/openml/migration/runs_migration_test.py @@ -0,0 +1,43 @@ +"""Migration tests comparing PHP and Python API responses for run trace endpoints.""" + +import asyncio +from http import HTTPStatus + +import deepdiff +import httpx +import pytest + +from core.conversions import nested_num_to_str + + +@pytest.mark.parametrize("run_id", [34]) +async def test_get_run_trace_equal( + run_id: int, + py_api: httpx.AsyncClient, + php_api: httpx.AsyncClient, +) -> None: + """Test that Python and PHP run trace responses are equivalent after normalization.""" + py_response, php_response = await asyncio.gather( + py_api.get(f"/run/trace/{run_id}"), + php_api.get(f"/run/trace/{run_id}"), + ) + assert py_response.status_code == HTTPStatus.OK + assert php_response.status_code == HTTPStatus.OK + + new_json = py_response.json() + + # PHP nests response under "trace" key — match that structure + new_json = {"trace": new_json} + + # PHP uses "trace_iteration" key, Python uses "trace" + new_json["trace"]["trace_iteration"] = new_json["trace"].pop("trace") + + # PHP returns all numeric values as strings — normalize Python response + new_json = nested_num_to_str(new_json) + + differences = deepdiff.diff.DeepDiff( + new_json, + php_response.json(), + ignore_order=True, + ) + assert not differences diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py new file mode 100644 index 00000000..2dc62edd --- /dev/null +++ b/tests/routers/openml/runs_test.py @@ -0,0 +1,49 @@ +"""Tests for the GET /run/trace/{run_id} endpoint.""" + +from http import HTTPStatus + +import httpx +import pytest + +from core.errors import RunNotFoundError, RunTraceNotFoundError + + +@pytest.mark.parametrize("run_id", [34]) +async def test_get_run_trace_success(run_id: int, py_api: httpx.AsyncClient) -> None: + """Test that trace data is returned for a run that has trace entries.""" + response = await py_api.get(f"/run/trace/{run_id}") + assert response.status_code == HTTPStatus.OK + body = response.json() + assert body["run_id"] == run_id + assert isinstance(body["trace"], list) + assert len(body["trace"]) > 0 + first = body["trace"][0] + assert isinstance(first["repeat"], int) + assert isinstance(first["fold"], int) + assert isinstance(first["iteration"], int) + assert first["selected"] in ("true", "false") + assert first["evaluation"] is None or isinstance(first["evaluation"], float) + + +@pytest.mark.parametrize("run_id", [24]) +async def test_get_run_trace_no_trace(run_id: int, py_api: httpx.AsyncClient) -> None: + """Test that 412 is returned for a run that exists but has no trace.""" + response = await py_api.get(f"/run/trace/{run_id}") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + body = response.json() + assert body["code"] == "572" # RunTraceNotFoundError code + assert body["type"] == RunTraceNotFoundError.uri + assert body["title"] == RunTraceNotFoundError.title + assert body["status"] == HTTPStatus.PRECONDITION_FAILED + + +@pytest.mark.parametrize("run_id", [999999]) +async def test_get_run_trace_run_not_found(run_id: int, py_api: httpx.AsyncClient) -> None: + """Test that 412 is returned when the run does not exist.""" + response = await py_api.get(f"/run/trace/{run_id}") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + body = response.json() + assert body["code"] == "571" # RunNotFoundError code + assert body["type"] == RunNotFoundError.uri + assert body["title"] == RunNotFoundError.title + assert body["status"] == HTTPStatus.PRECONDITION_FAILED