diff --git a/src/core/tagging.py b/src/core/tagging.py new file mode 100644 index 00000000..8a69567d --- /dev/null +++ b/src/core/tagging.py @@ -0,0 +1,51 @@ +from collections.abc import Awaitable, Callable +from typing import Any + +from sqlalchemy import Row +from sqlalchemy.ext.asyncio import AsyncConnection + +from core.errors import TagAlreadyExistsError, TagNotFoundError, TagNotOwnedError +from database.users import User, UserGroup + + +async def tag_entity( + entity_id: int, + tag: str, + user: User, + expdb: AsyncConnection, + *, + get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]], + tag_fn: Callable[..., Awaitable[None]], + response_key: str, +) -> dict[str, dict[str, Any]]: + tags = await get_tags_fn(entity_id, expdb) + if tag.casefold() in (t.casefold() for t in tags): + msg = f"Entity {entity_id} already tagged with {tag!r}." + raise TagAlreadyExistsError(msg) + await tag_fn(entity_id, tag, user_id=user.user_id, expdb=expdb) + tags = await get_tags_fn(entity_id, expdb) + return {response_key: {"id": str(entity_id), "tag": tags}} + + +async def untag_entity( + entity_id: int, + tag: str, + user: User, + expdb: AsyncConnection, + *, + get_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[Row | None]], + delete_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[None]], + get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]], + response_key: str, +) -> dict[str, dict[str, Any]]: + existing = await get_tag_fn(entity_id, tag, expdb) + if existing is None: + msg = f"Tag {tag!r} not found on entity {entity_id}." + raise TagNotFoundError(msg) + groups = await user.get_groups() + if existing.uploader != user.user_id and UserGroup.ADMIN not in groups: + msg = f"Tag {tag!r} on entity {entity_id} is not owned by you." + raise TagNotOwnedError(msg) + await delete_tag_fn(entity_id, tag, expdb) + tags = await get_tags_fn(entity_id, expdb) + return {response_key: {"id": str(entity_id), "tag": tags}} diff --git a/src/database/flows.py b/src/database/flows.py index 79bb6e5b..79d4f5b9 100644 --- a/src/database/flows.py +++ b/src/database/flows.py @@ -4,6 +4,11 @@ from sqlalchemy import Row, text from sqlalchemy.ext.asyncio import AsyncConnection +from database.tagging import insert_tag, remove_tag, select_tag, select_tags + +_TABLE = "implementation_tag" +_ID_COLUMN = "id" + async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: rows = await expdb.execute( @@ -23,18 +28,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]: async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]: - rows = await expdb.execute( - text( - """ - SELECT tag - FROM implementation_tag - WHERE id = :flow_id - """, - ), - parameters={"flow_id": flow_id}, - ) - tag_rows = rows.all() - return [tag.tag for tag in tag_rows] + return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=flow_id, expdb=expdb) async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]: @@ -54,6 +48,25 @@ async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]: ) +async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None: + await insert_tag( + table=_TABLE, + id_column=_ID_COLUMN, + id_=id_, + tag_=tag_, + user_id=user_id, + expdb=expdb, + ) + + +async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None: + return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb) + + +async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None: + await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb) + + async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None: """Get flow by name and external version.""" row = await expdb.execute( diff --git a/src/database/runs.py b/src/database/runs.py new file mode 100644 index 00000000..98d92e91 --- /dev/null +++ b/src/database/runs.py @@ -0,0 +1,44 @@ +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection + +from database.tagging import insert_tag, remove_tag, select_tag, select_tags + +_TABLE = "run_tag" +_ID_COLUMN = "id" + + +async def get(id_: int, expdb: AsyncConnection) -> Row | None: + row = await expdb.execute( + text( + """ + SELECT * + FROM run + WHERE `id` = :run_id + """, + ), + parameters={"run_id": id_}, + ) + return row.one_or_none() + + +async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: + return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb) + + +async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None: + await insert_tag( + table=_TABLE, + id_column=_ID_COLUMN, + id_=id_, + tag_=tag_, + user_id=user_id, + expdb=expdb, + ) + + +async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None: + return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb) + + +async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None: + await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb) diff --git a/src/database/tagging.py b/src/database/tagging.py new file mode 100644 index 00000000..8aeef10d --- /dev/null +++ b/src/database/tagging.py @@ -0,0 +1,82 @@ +from sqlalchemy import Row, text +from sqlalchemy.ext.asyncio import AsyncConnection + + +async def insert_tag( + *, + table: str, + id_column: str, + id_: int, + tag_: str, + user_id: int, + expdb: AsyncConnection, +) -> None: + await expdb.execute( + text( + f""" + INSERT INTO {table}(`{id_column}`, `tag`, `uploader`) + VALUES (:id, :tag, :user_id) + """, + ), + parameters={"id": id_, "tag": tag_, "user_id": user_id}, + ) + + +async def select_tag( + *, + table: str, + id_column: str, + id_: int, + tag_: str, + expdb: AsyncConnection, +) -> Row | None: + result = await expdb.execute( + text( + f""" + SELECT `{id_column}` as id, `tag`, `uploader` + FROM {table} + WHERE `{id_column}` = :id AND `tag` = :tag + """, + ), + parameters={"id": id_, "tag": tag_}, + ) + return result.one_or_none() + + +async def remove_tag( + *, + table: str, + id_column: str, + id_: int, + tag_: str, + expdb: AsyncConnection, +) -> None: + await expdb.execute( + text( + f""" + DELETE FROM {table} + WHERE `{id_column}` = :id AND `tag` = :tag + """, + ), + parameters={"id": id_, "tag": tag_}, + ) + + +async def select_tags( + *, + table: str, + id_column: str, + id_: int, + expdb: AsyncConnection, +) -> list[str]: + result = await expdb.execute( + text( + f""" + SELECT `tag` + FROM {table} + WHERE `{id_column}` = :id + """, + ), + parameters={"id": id_}, + ) + return [row.tag for row in result.all()] diff --git a/src/database/tasks.py b/src/database/tasks.py index e9670d26..1a48d647 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -4,6 +4,11 @@ from sqlalchemy import Row, text from sqlalchemy.ext.asyncio import AsyncConnection +from database.tagging import insert_tag, remove_tag, select_tag, select_tags + +_TABLE = "task_tag" +_ID_COLUMN = "id" + async def get(id_: int, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( @@ -103,15 +108,23 @@ async def get_task_type_inout_with_template( async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: - rows = await expdb.execute( - text( - """ - SELECT `tag` - FROM task_tag - WHERE `id` = :task_id - """, - ), - parameters={"task_id": id_}, + return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb) + + +async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None: + await insert_tag( + table=_TABLE, + id_column=_ID_COLUMN, + id_=id_, + tag_=tag_, + user_id=user_id, + expdb=expdb, ) - tag_rows = rows.all() - return [row.tag for row in tag_rows] + + +async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None: + return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb) + + +async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None: + await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb) diff --git a/src/main.py b/src/main.py index 76a52ad3..f30a80c5 100644 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.runs import router as runs_router from routers.openml.setups import router as setup_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router @@ -68,6 +69,7 @@ def create_api() -> FastAPI: app.include_router(estimationprocedure_router) app.include_router(task_router) app.include_router(flows_router) + app.include_router(runs_router) app.include_router(study_router) app.include_router(setup_router) return app diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py index 41254863..8eaa0765 100644 --- a/src/routers/openml/flows.py +++ b/src/routers/openml/flows.py @@ -1,17 +1,57 @@ -from typing import Annotated, Literal +from typing import Annotated, Any, Literal -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Body, Depends from sqlalchemy.ext.asyncio import AsyncConnection import database.flows from core.conversions import _str_to_num from core.errors import FlowNotFoundError -from routers.dependencies import expdb_connection +from core.tagging import tag_entity, untag_entity +from database.users import User +from routers.dependencies import expdb_connection, fetch_user_or_raise +from routers.types import SystemString64 from schemas.flows import Flow, Parameter, Subflow router = APIRouter(prefix="/flows", tags=["flows"]) +@router.post(path="/tag") +async def tag_flow( + flow_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + return await tag_entity( + flow_id, + tag, + user, + expdb, + get_tags_fn=database.flows.get_tags, + tag_fn=database.flows.tag, + response_key="flow_tag", + ) + + +@router.post(path="/untag") +async def untag_flow( + flow_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + return await untag_entity( + flow_id, + tag, + user, + expdb, + get_tag_fn=database.flows.get_tag, + delete_tag_fn=database.flows.delete_tag, + get_tags_fn=database.flows.get_tags, + response_key="flow_tag", + ) + + @router.get("/exists/{name}/{external_version}") async def flow_exists( name: str, diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py new file mode 100644 index 00000000..fdea8f6d --- /dev/null +++ b/src/routers/openml/runs.py @@ -0,0 +1,49 @@ +from typing import Annotated, Any + +from fastapi import APIRouter, Body, Depends +from sqlalchemy.ext.asyncio import AsyncConnection + +import database.runs +from core.tagging import tag_entity, untag_entity +from database.users import User +from routers.dependencies import expdb_connection, fetch_user_or_raise +from routers.types import SystemString64 + +router = APIRouter(prefix="/runs", tags=["runs"]) + + +@router.post(path="/tag") +async def tag_run( + run_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + return await tag_entity( + run_id, + tag, + user, + expdb, + get_tags_fn=database.runs.get_tags, + tag_fn=database.runs.tag, + response_key="run_tag", + ) + + +@router.post(path="/untag") +async def untag_run( + run_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + return await untag_entity( + run_id, + tag, + user, + expdb, + get_tag_fn=database.runs.get_tag, + delete_tag_fn=database.runs.delete_tag, + get_tags_fn=database.runs.get_tags, + response_key="run_tag", + ) diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 788cd804..e7a19a66 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -1,9 +1,9 @@ import json import re -from typing import Annotated, cast +from typing import Annotated, Any, cast import xmltodict -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Body, Depends from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncConnection @@ -11,7 +11,10 @@ import database.datasets import database.tasks from core.errors import InternalError, TaskNotFoundError -from routers.dependencies import expdb_connection +from core.tagging import tag_entity, untag_entity +from database.users import User +from routers.dependencies import expdb_connection, fetch_user_or_raise +from routers.types import SystemString64 from schemas.datasets.openml import Task router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -157,6 +160,43 @@ async def _fill_json_template( # noqa: C901 return template.replace("[CONSTANT:base_url]", server_url) +@router.post(path="/tag") +async def tag_task( + task_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + return await tag_entity( + task_id, + tag, + user, + expdb, + get_tags_fn=database.tasks.get_tags, + tag_fn=database.tasks.tag, + response_key="task_tag", + ) + + +@router.post(path="/untag") +async def untag_task( + task_id: Annotated[int, Body()], + tag: Annotated[str, SystemString64], + user: Annotated[User, Depends(fetch_user_or_raise)], + expdb: Annotated[AsyncConnection, Depends(expdb_connection)], +) -> dict[str, dict[str, Any]]: + return await untag_entity( + task_id, + tag, + user, + expdb, + get_tag_fn=database.tasks.get_tag, + delete_tag_fn=database.tasks.delete_tag, + get_tags_fn=database.tasks.get_tags, + response_key="task_tag", + ) + + @router.get("/{task_id}") async def get_task( task_id: int, diff --git a/tests/routers/openml/flow_tag_test.py b/tests/routers/openml/flow_tag_test.py new file mode 100644 index 00000000..f9193154 --- /dev/null +++ b/tests/routers/openml/flow_tag_test.py @@ -0,0 +1,151 @@ +from http import HTTPStatus + +import httpx +import pytest +from sqlalchemy.ext.asyncio import AsyncConnection + +from database.flows import get_tags +from tests.conftest import Flow +from tests.users import ApiKey + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_flow_tag_rejects_unauthorized( + key: ApiKey | None, + py_api: httpx.AsyncClient, +) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/flows/tag{apikey}", + json={"flow_id": 1, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +async def test_flow_tag( + flow: Flow, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + response = await py_api.post( + f"/flows/tag?api_key={ApiKey.ADMIN}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"flow_tag": {"id": str(flow.id), "tag": [tag]}} + + tags = await get_tags(flow_id=flow.id, expdb=expdb_test) + assert tag in tags + + +async def test_flow_tag_returns_existing_tags(py_api: httpx.AsyncClient) -> None: + flow_id, tag = 1, "test" + response = await py_api.post( + f"/flows/tag?api_key={ApiKey.ADMIN}", + json={"flow_id": flow_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + result = response.json() + assert result["flow_tag"]["id"] == str(flow_id) + assert "OpenmlWeka" in result["flow_tag"]["tag"] + assert "weka" in result["flow_tag"]["tag"] + assert tag in result["flow_tag"]["tag"] + + +async def test_flow_tag_fails_if_tag_exists(py_api: httpx.AsyncClient) -> None: + flow_id, tag = 1, "OpenmlWeka" + response = await py_api.post( + f"/flows/tag?api_key={ApiKey.ADMIN}", + json={"flow_id": flow_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.CONFLICT + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_flow_untag_rejects_unauthorized( + key: ApiKey | None, + py_api: httpx.AsyncClient, +) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/flows/untag{apikey}", + json={"flow_id": 1, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +async def test_flow_untag( + flow: Flow, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + setup = await py_api.post( + f"/flows/tag?api_key={ApiKey.ADMIN}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/flows/untag?api_key={ApiKey.ADMIN}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"flow_tag": {"id": str(flow.id), "tag": []}} + + tags = await get_tags(flow_id=flow.id, expdb=expdb_test) + assert tag not in tags + + +async def test_flow_untag_fails_if_tag_not_found(py_api: httpx.AsyncClient) -> None: + response = await py_api.post( + f"/flows/untag?api_key={ApiKey.ADMIN}", + json={"flow_id": 1, "tag": "nonexistent"}, + ) + assert response.status_code == HTTPStatus.NOT_FOUND + + +async def test_flow_untag_non_admin_own_tag( + flow: Flow, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + tag = "user_tag" + setup = await py_api.post( + f"/flows/tag?api_key={ApiKey.SOME_USER}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/flows/untag?api_key={ApiKey.SOME_USER}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + + tags = await get_tags(flow_id=flow.id, expdb=expdb_test) + assert tag not in tags + + +async def test_flow_untag_fails_if_not_owner( + flow: Flow, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + setup = await py_api.post( + f"/flows/tag?api_key={ApiKey.ADMIN}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/flows/untag?api_key={ApiKey.SOME_USER}", + json={"flow_id": flow.id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.FORBIDDEN diff --git a/tests/routers/openml/run_tag_test.py b/tests/routers/openml/run_tag_test.py new file mode 100644 index 00000000..dcde7466 --- /dev/null +++ b/tests/routers/openml/run_tag_test.py @@ -0,0 +1,165 @@ +from http import HTTPStatus + +import httpx +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncConnection + +from database.runs import get_tags +from tests.users import ApiKey + + +@pytest.fixture +async def run_id(expdb_test: AsyncConnection) -> int: + await expdb_test.execute( + text( + """ + INSERT INTO run(`uploader`, `task_id`, `setup`) + VALUES (1, 59, 1); + """, + ), + ) + result = await expdb_test.execute(text("SELECT LAST_INSERT_ID();")) + (rid,) = result.one() + return int(rid) + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_run_tag_rejects_unauthorized( + key: ApiKey | None, + run_id: int, + py_api: httpx.AsyncClient, +) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/runs/tag{apikey}", + json={"run_id": run_id, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +async def test_run_tag( + run_id: int, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + response = await py_api.post( + f"/runs/tag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"run_tag": {"id": str(run_id), "tag": [tag]}} + + tags = await get_tags(id_=run_id, expdb=expdb_test) + assert tag in tags + + +async def test_run_tag_fails_if_tag_exists( + run_id: int, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + setup = await py_api.post( + f"/runs/tag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/runs/tag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.CONFLICT + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_run_untag_rejects_unauthorized( + key: ApiKey | None, + run_id: int, + py_api: httpx.AsyncClient, +) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/runs/untag{apikey}", + json={"run_id": run_id, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +async def test_run_untag( + run_id: int, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + setup = await py_api.post( + f"/runs/tag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/runs/untag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"run_tag": {"id": str(run_id), "tag": []}} + + tags = await get_tags(id_=run_id, expdb=expdb_test) + assert tag not in tags + + +async def test_run_untag_fails_if_tag_not_found( + run_id: int, + py_api: httpx.AsyncClient, +) -> None: + response = await py_api.post( + f"/runs/untag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": "nonexistent"}, + ) + assert response.status_code == HTTPStatus.NOT_FOUND + + +async def test_run_untag_non_admin_own_tag( + run_id: int, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + tag = "user_tag" + setup = await py_api.post( + f"/runs/tag?api_key={ApiKey.SOME_USER}", + json={"run_id": run_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/runs/untag?api_key={ApiKey.SOME_USER}", + json={"run_id": run_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + + tags = await get_tags(id_=run_id, expdb=expdb_test) + assert tag not in tags + + +async def test_run_untag_fails_if_not_owner( + run_id: int, + py_api: httpx.AsyncClient, +) -> None: + tag = "test" + setup = await py_api.post( + f"/runs/tag?api_key={ApiKey.ADMIN}", + json={"run_id": run_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/runs/untag?api_key={ApiKey.SOME_USER}", + json={"run_id": run_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.FORBIDDEN diff --git a/tests/routers/openml/task_tag_test.py b/tests/routers/openml/task_tag_test.py new file mode 100644 index 00000000..85da529c --- /dev/null +++ b/tests/routers/openml/task_tag_test.py @@ -0,0 +1,141 @@ +from http import HTTPStatus + +import httpx +import pytest +from sqlalchemy.ext.asyncio import AsyncConnection + +from database.tasks import get_tags +from tests.users import ApiKey + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_task_tag_rejects_unauthorized( + key: ApiKey | None, + py_api: httpx.AsyncClient, +) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/tasks/tag{apikey}", + json={"task_id": 59, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.parametrize( + "key", + [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER], + ids=["administrator", "non-owner", "owner"], +) +async def test_task_tag( + key: ApiKey, + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + task_id, tag = 59, "test" + response = await py_api.post( + f"/tasks/tag?api_key={key}", + json={"task_id": task_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"task_tag": {"id": str(task_id), "tag": [tag]}} + + tags = await get_tags(id_=task_id, expdb=expdb_test) + assert tag in tags + + +async def test_task_tag_fails_if_tag_exists(py_api: httpx.AsyncClient) -> None: + task_id, tag = 59, "test" + setup = await py_api.post( + f"/tasks/tag?api_key={ApiKey.ADMIN}", + json={"task_id": task_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/tasks/tag?api_key={ApiKey.ADMIN}", + json={"task_id": task_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.CONFLICT + + +@pytest.mark.parametrize( + "key", + [None, ApiKey.INVALID], + ids=["no authentication", "invalid key"], +) +async def test_task_untag_rejects_unauthorized( + key: ApiKey | None, + py_api: httpx.AsyncClient, +) -> None: + apikey = "" if key is None else f"?api_key={key}" + response = await py_api.post( + f"/tasks/untag{apikey}", + json={"task_id": 59, "tag": "test"}, + ) + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +async def test_task_untag( + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + task_id, tag = 59, "test" + setup = await py_api.post( + f"/tasks/tag?api_key={ApiKey.ADMIN}", + json={"task_id": task_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/tasks/untag?api_key={ApiKey.ADMIN}", + json={"task_id": task_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + assert response.json() == {"task_tag": {"id": str(task_id), "tag": []}} + + tags = await get_tags(id_=task_id, expdb=expdb_test) + assert tag not in tags + + +async def test_task_untag_fails_if_tag_not_found(py_api: httpx.AsyncClient) -> None: + response = await py_api.post( + f"/tasks/untag?api_key={ApiKey.ADMIN}", + json={"task_id": 59, "tag": "nonexistent"}, + ) + assert response.status_code == HTTPStatus.NOT_FOUND + + +async def test_task_untag_non_admin_own_tag( + expdb_test: AsyncConnection, + py_api: httpx.AsyncClient, +) -> None: + task_id, tag = 59, "user_tag" + setup = await py_api.post( + f"/tasks/tag?api_key={ApiKey.SOME_USER}", + json={"task_id": task_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/tasks/untag?api_key={ApiKey.SOME_USER}", + json={"task_id": task_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.OK + + tags = await get_tags(id_=task_id, expdb=expdb_test) + assert tag not in tags + + +async def test_task_untag_fails_if_not_owner(py_api: httpx.AsyncClient) -> None: + task_id, tag = 59, "test" + setup = await py_api.post( + f"/tasks/tag?api_key={ApiKey.ADMIN}", + json={"task_id": task_id, "tag": tag}, + ) + assert setup.status_code == HTTPStatus.OK + response = await py_api.post( + f"/tasks/untag?api_key={ApiKey.SOME_USER}", + json={"task_id": task_id, "tag": tag}, + ) + assert response.status_code == HTTPStatus.FORBIDDEN