diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 788cd80..2d2772c 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -1,17 +1,19 @@ import json import re -from typing import Annotated, cast +from enum import StrEnum +from typing import Annotated, Any, cast import xmltodict -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Body, Depends from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncConnection import config import database.datasets import database.tasks -from core.errors import InternalError, TaskNotFoundError -from routers.dependencies import expdb_connection +from core.errors import InternalError, NoResultsError, TaskNotFoundError +from routers.dependencies import Pagination, expdb_connection +from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex from schemas.datasets.openml import Task router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -157,6 +159,239 @@ async def _fill_json_template( # noqa: C901 return template.replace("[CONSTANT:base_url]", server_url) +class TaskStatusFilter(StrEnum): + """Valid values for the status filter.""" + + ACTIVE = "active" + DEACTIVATED = "deactivated" + IN_PREPARATION = "in_preparation" + ALL = "all" + + +QUALITIES_TO_SHOW = [ + "MajorityClassSize", + "MaxNominalAttDistinctValues", + "MinorityClassSize", + "NumberOfClasses", + "NumberOfFeatures", + "NumberOfInstances", + "NumberOfInstancesWithMissingValues", + "NumberOfMissingValues", + "NumberOfNumericFeatures", + "NumberOfSymbolicFeatures", +] + +BASIC_TASK_INPUTS = [ + "source_data", + "target_feature", + "estimation_procedure", + "evaluation_measures", +] + + +def _quality_clause(quality: str, range_: str | None) -> str: + """Return a SQL WHERE clause fragment filtering tasks by a dataset quality range. + + Looks up tasks whose source dataset has the given quality within the range. 
+ Range can be exact ('100') or a range ('50..200'). + """ + if not range_: + return "" + if not (match := re.match(integer_range_regex, range_)): + msg = f"`range_` not a valid range: {range_}" + raise ValueError(msg) + start, end = match.groups() + # end group looks like "..200", strip the ".." prefix to get just the number + value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" + # nested subquery: find datasets with matching quality, then find tasks using those datasets + return f""" + AND t.`task_id` IN ( + SELECT ti.`task_id` FROM task_inputs ti + WHERE ti.`input`='source_data' AND ti.`value` IN ( + SELECT `data` FROM data_quality + WHERE `quality`='{quality}' AND {value} + ) + ) + """ # noqa: S608 + + +@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") +@router.get(path="/list") +async def list_tasks( # noqa: PLR0913 + pagination: Annotated[Pagination, Body(default_factory=Pagination)], + task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None, + tag: Annotated[str | None, SystemString64] = None, + data_tag: Annotated[str | None, SystemString64] = None, + status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE, + task_id: Annotated[list[int] | None, Body(description="Filter by task id(s).")] = None, + data_id: Annotated[list[int] | None, Body(description="Filter by dataset id(s).")] = None, + data_name: Annotated[str | None, CasualString128] = None, + number_instances: Annotated[str | None, IntegerRange] = None, + number_features: Annotated[str | None, IntegerRange] = None, + number_classes: Annotated[str | None, IntegerRange] = None, + number_missing_values: Annotated[str | None, IntegerRange] = None, + expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, +) -> list[dict[str, Any]]: + """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" + assert expdb is not None # noqa: S101 + + # --- WHERE 
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_tasks(  # noqa: PLR0913
    pagination: Annotated[Pagination, Body(default_factory=Pagination)],
    task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None,
    tag: Annotated[str | None, SystemString64] = None,
    data_tag: Annotated[str | None, SystemString64] = None,
    status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE,
    task_id: Annotated[list[int] | None, Body(description="Filter by task id(s).")] = None,
    data_id: Annotated[list[int] | None, Body(description="Filter by dataset id(s).")] = None,
    data_name: Annotated[str | None, CasualString128] = None,
    number_instances: Annotated[str | None, IntegerRange] = None,
    number_features: Annotated[str | None, IntegerRange] = None,
    number_classes: Annotated[str | None, IntegerRange] = None,
    number_missing_values: Annotated[str | None, IntegerRange] = None,
    expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
    """List tasks, optionally filtered by type, tag, status, dataset properties, and more.

    All filters are optional and are combined with ``AND``. An empty
    ``task_id`` or ``data_id`` list is treated like ``None`` (no filter), and
    ``status=all`` disables the status filter. A dataset without any row in
    ``dataset_status`` is reported, and filtered, as ``in_preparation``.

    Each returned dict holds ``task_id``, ``task_type_id``, ``task_type``,
    ``did``, ``name``, ``format`` and ``status``, plus three lists that are
    always present: ``input`` (task inputs named in ``BASIC_TASK_INPUTS``),
    ``quality`` (dataset qualities named in ``QUALITIES_TO_SHOW``, with values
    converted to strings) and ``tag``.

    Raises ``NoResultsError`` when no task matches the filters.
    """
    # NOTE(review): `expdb` defaults to None only in the signature; presumably the
    # dependency always provides a connection, and this assert narrows the type.
    assert expdb is not None  # noqa: S101

    # --- WHERE clauses ---
    # User-supplied strings are bound as parameters. Only integers (ids, range
    # bounds, limit/offset) and module constants are written into the SQL text.
    where_status = (
        ""
        if status == TaskStatusFilter.ALL
        else "AND IFNULL(ds.`status`, 'in_preparation') = :status"
    )
    where_type = "" if task_type_id is None else "AND t.`ttid` = :task_type_id"
    where_tag = (
        "" if tag is None else "AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)"
    )
    where_data_tag = (
        ""
        if data_tag is None
        else "AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)"
    )
    task_id_str = ",".join(str(tid) for tid in task_id) if task_id else ""
    where_task_id = f"AND t.`task_id` IN ({task_id_str})" if task_id else ""
    data_id_str = ",".join(str(did) for did in data_id) if data_id else ""
    where_data_id = f"AND d.`did` IN ({data_id_str})" if data_id else ""
    where_data_name = "" if data_name is None else "AND d.`name` = :data_name"

    where_number_instances = _quality_clause("NumberOfInstances", number_instances)
    where_number_features = _quality_clause("NumberOfFeatures", number_features)
    where_number_classes = _quality_clause("NumberOfClasses", number_classes)
    where_number_missing_values = _quality_clause("NumberOfMissingValues", number_missing_values)

    basic_inputs_str = ", ".join(f"'{i}'" for i in BASIC_TASK_INPUTS)

    # subquery to get the latest status per dataset
    # dataset_status has multiple rows per dataset (history), we want only the most recent
    status_subquery = """
        SELECT ds1.did, ds1.status
        FROM dataset_status ds1
        WHERE ds1.status_date = (
            SELECT MAX(ds2.status_date) FROM dataset_status ds2
            WHERE ds1.did = ds2.did
        )
    """

    # NOTE(review): limit/offset are written into the SQL text; this assumes
    # `Pagination` validates them as integers - confirm.
    query = text(
        f"""
        SELECT
            t.`task_id`,
            t.`ttid` AS task_type_id,
            tt.`name` AS task_type,
            d.`did`,
            d.`name`,
            d.`format`,
            IFNULL(ds.`status`, 'in_preparation') AS status
        FROM task t
        JOIN task_type tt ON tt.`ttid` = t.`ttid`
        JOIN task_inputs ti_source ON ti_source.`task_id` = t.`task_id`
            AND ti_source.`input` = 'source_data'
        JOIN dataset d ON d.`did` = ti_source.`value`
        LEFT JOIN ({status_subquery}) ds ON ds.`did` = d.`did`
        WHERE 1=1
            {where_status}
            {where_type}
            {where_tag}
            {where_data_tag}
            {where_task_id}
            {where_data_id}
            {where_data_name}
            {where_number_instances}
            {where_number_features}
            {where_number_classes}
            {where_number_missing_values}
        GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status`
        ORDER BY t.`task_id`
        LIMIT {pagination.limit} OFFSET {pagination.offset}
        """,  # noqa: S608
    )

    result = await expdb.execute(
        query,
        parameters={
            # str() yields the plain value for a StrEnum member
            "status": str(status),
            "task_type_id": task_type_id,
            "tag": tag,
            "data_tag": data_tag,
            "data_name": data_name,
        },
    )
    rows = result.mappings().all()

    if not rows:
        msg = "No tasks match the search criteria."
        raise NoResultsError(msg)

    columns = ["task_id", "task_type_id", "task_type", "did", "name", "format", "status"]
    # Create the list-valued keys up front so every task has the same shape,
    # e.g. a task with no tags returns "tag": [] rather than a missing key.
    tasks: dict[int, dict[str, Any]] = {
        row["task_id"]: {
            **{col: row[col] for col in columns},
            "input": [],
            "quality": [],
            "tag": [],
        }
        for row in rows
    }

    # fetch inputs for all tasks in one query
    task_ids_str = ",".join(str(tid) for tid in tasks)
    inputs_result = await expdb.execute(
        text(
            f"""
            SELECT `task_id`, `input`, `value`
            FROM task_inputs
            WHERE `task_id` IN ({task_ids_str})
            AND `input` IN ({basic_inputs_str})
            """,  # noqa: S608
        ),
    )
    for row in inputs_result.all():
        tasks[row.task_id]["input"].append({"name": row.input, "value": row.value})

    # fetch qualities for all datasets in one query
    did_list = ",".join(str(t["did"]) for t in tasks.values())
    qualities_str = ", ".join(f"'{q}'" for q in QUALITIES_TO_SHOW)
    qualities_result = await expdb.execute(
        text(
            f"""
            SELECT `data`, `quality`, `value`
            FROM data_quality
            WHERE `data` IN ({did_list})
            AND `quality` IN ({qualities_str})
            """,  # noqa: S608
        ),
    )
    # build a reverse map: dataset_id -> task ids
    # needed because quality rows come back keyed by did, but the tasks dict is keyed by task_id
    did_to_task_ids: dict[int, list[int]] = {}
    for tid, task in tasks.items():
        did_to_task_ids.setdefault(task["did"], []).append(tid)
    for row in qualities_result.all():
        for tid in did_to_task_ids.get(row.data, []):
            tasks[tid]["quality"].append({"name": row.quality, "value": str(row.value)})

    # fetch tags for all tasks in one query
    tags_result = await expdb.execute(
        text(
            f"""
            SELECT `id`, `tag`
            FROM task_tag
            WHERE `id` IN ({task_ids_str})
            """,  # noqa: S608
        ),
    )
    for row in tags_result.all():
        tasks[row.id]["tag"].append(row.tag)

    return list(tasks.values())
async def test_list_tasks_filter_type(py_api: httpx.AsyncClient) -> None:
    """Filter by task_type_id returns only tasks of that type."""
    response = await py_api.post("/tasks/list", json={"task_type_id": 1})
    assert response.status_code == HTTPStatus.OK
    tasks = response.json()
    assert len(tasks) > 0
    assert all(t["task_type_id"] == 1 for t in tasks)


async def test_list_tasks_filter_tag(py_api: httpx.AsyncClient) -> None:
    """Filter by tag returns only tasks with that tag."""
    response = await py_api.post("/tasks/list", json={"tag": "OpenML100"})
    assert response.status_code == HTTPStatus.OK
    tasks = response.json()
    assert len(tasks) > 0
    assert all("OpenML100" in t["tag"] for t in tasks)


async def test_list_tasks_pagination(py_api: httpx.AsyncClient) -> None:
    """Pagination returns correct number of results."""
    limit = 5
    response = await py_api.post(
        "/tasks/list",
        json={"pagination": {"limit": limit, "offset": 0}},
    )
    assert response.status_code == HTTPStatus.OK
    assert len(response.json()) == limit


async def test_list_tasks_pagination_offset(py_api: httpx.AsyncClient) -> None:
    """Offset returns different results than no offset."""
    r1 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 0}})
    r2 = await py_api.post("/tasks/list", json={"pagination": {"limit": 5, "offset": 5}})
    # Check both requests succeeded first: an error body is a dict, and the
    # comprehensions below would then fail with an unrelated TypeError.
    assert r1.status_code == HTTPStatus.OK
    assert r2.status_code == HTTPStatus.OK
    ids1 = [t["task_id"] for t in r1.json()]
    ids2 = [t["task_id"] for t in r2.json()]
    assert ids1 != ids2


async def test_list_tasks_number_instances_range(py_api: httpx.AsyncClient) -> None:
    """number_instances range filter returns tasks whose dataset matches."""
    min_instances, max_instances = 100, 1000
    response = await py_api.post(
        "/tasks/list",
        json={"number_instances": f"{min_instances}..{max_instances}"},
    )
    assert response.status_code == HTTPStatus.OK
    tasks = response.json()
    assert len(tasks) > 0
    for task in tasks:
        # quality values are serialized as strings, hence the float() conversion
        qualities = {q["name"]: q["value"] for q in task["quality"]}
        assert min_instances <= float(qualities["NumberOfInstances"]) <= max_instances


async def test_list_tasks_no_results(py_api: httpx.AsyncClient) -> None:
    """Nonexistent tag returns 404 NoResultsError."""
    response = await py_api.post("/tasks/list", json={"tag": "nonexistent_tag_xyz"})
    assert response.status_code == HTTPStatus.NOT_FOUND
    assert response.headers["content-type"] == "application/problem+json"
    error = response.json()
    assert error["status"] == HTTPStatus.NOT_FOUND
    assert error["code"] == "372"


async def test_list_tasks_get(py_api: httpx.AsyncClient) -> None:
    """GET /tasks/list with no body also works."""
    response = await py_api.get("/tasks/list")
    assert response.status_code == HTTPStatus.OK
    assert isinstance(response.json(), list)


async def test_list_tasks_invalid_range_format(py_api: httpx.AsyncClient) -> None:
    """Invalid number_instances range returns 422 validation error."""
    response = await py_api.post("/tasks/list", json={"number_instances": "1...2"})
    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY


async def test_get_task(py_api: httpx.AsyncClient) -> None:
    response = await py_api.get("/tasks/59")
    assert response.status_code == HTTPStatus.OK