diff --git a/src/database/datasets.py b/src/database/datasets.py
index 4e76dcf9..561869ab 100644
--- a/src/database/datasets.py
+++ b/src/database/datasets.py
@@ -68,6 +68,37 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
     )
 
 
+async def get_tags(id_: int, connection: AsyncConnection) -> list[Row]:
+    """Return all rows of `dataset_tag` whose `id` equals `id_`."""
+    result = await connection.execute(
+        text(
+            """
+    SELECT *
+    FROM dataset_tag
+    WHERE id = :dataset_id
+    """,
+        ),
+        parameters={"dataset_id": id_},
+    )
+    return list(result.all())
+
+
+async def untag(id_: int, tag_: str, *, connection: AsyncConnection) -> None:
+    """Delete the rows of `dataset_tag` that match both `id_` and `tag_`."""
+    await connection.execute(
+        text(
+            """
+    DELETE FROM dataset_tag
+    WHERE `id` = :dataset_id AND `tag` = :tag
+    """,
+        ),
+        parameters={
+            "dataset_id": id_,
+            "tag": tag_,
+        },
+    )
+
+
 async def get_description(
     id_: int,
     connection: AsyncConnection,
diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py
index d86ed848..d2ff7f1a 100644
--- a/src/routers/openml/datasets.py
+++ b/src/routers/openml/datasets.py
@@ -25,6 +25,8 @@
     InternalError,
     NoResultsError,
     TagAlreadyExistsError,
+    TagNotFoundError,
+    TagNotOwnedError,
 )
 from core.formatting import (
     _csv_as_list,
@@ -66,6 +68,48 @@ async def tag_dataset(
     }
 
 
+@router.post(
+    path="/untag",
+)
+async def untag_dataset(
+    data_id: Annotated[int, Body()],
+    tag: Annotated[str, SystemString64],
+    user: Annotated[User, Depends(fetch_user_or_raise)],
+    expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
+) -> dict[str, dict[str, Any]]:
+    """Remove a tag from a dataset.
+
+    The tag is matched case-insensitively and may only be removed by the user
+    recorded as its uploader or by an administrator.
+    """
+    assert expdb_db is not None  # noqa: S101
+    if not await database.datasets.get(data_id, expdb_db):
+        msg = f"No dataset with id {data_id} found."
+        raise DatasetNotFoundError(msg)
+
+    # Compare case-insensitively, but delete using the tag exactly as stored.
+    wanted = tag.casefold()
+    dataset_tags = await database.datasets.get_tags(data_id, expdb_db)
+    matched_tag_row = next((t for t in dataset_tags if t.tag.casefold() == wanted), None)
+    if matched_tag_row is None:
+        msg = f"Dataset {data_id} does not have tag {tag!r}."
+        raise TagNotFoundError(msg)
+
+    # Administrators may remove any tag; everyone else only their own.
+    is_owner = matched_tag_row.uploader == user.user_id
+    if not is_owner and UserGroup.ADMIN not in await user.get_groups():
+        msg = (
+            f"You may not remove tag {tag!r} of dataset {data_id} "
+            "because it was not created by you."
+        )
+        raise TagNotOwnedError(msg)
+
+    await database.datasets.untag(data_id, matched_tag_row.tag, connection=expdb_db)
+    return {
+        "data_untag": {"id": str(data_id)},
+    }
+
+
 class DatasetStatusFilter(StrEnum):
     ACTIVE = DatasetStatus.ACTIVE
     DEACTIVATED = DatasetStatus.DEACTIVATED
diff --git a/tests/routers/openml/dataset_tag_test.py b/tests/routers/openml/dataset_tag_test.py
index 25042c89..586428f2 100644
--- a/tests/routers/openml/dataset_tag_test.py
+++ b/tests/routers/openml/dataset_tag_test.py
@@ -89,3 +89,99 @@ async def test_dataset_tag_invalid_tag_is_rejected(
 
     assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
     assert new.json()["detail"][0]["loc"] == ["body", "tag"]
+
+
+@pytest.mark.parametrize(
+    "key",
+    [None, ApiKey.INVALID],
+    ids=["no authentication", "invalid key"],
+)
+async def test_dataset_untag_rejects_unauthorized(
+    key: ApiKey | None,
+    py_api: httpx.AsyncClient,
+) -> None:
+    apikey = "" if key is None else f"?api_key={key}"
+    response = await py_api.post(
+        f"/datasets/untag{apikey}",
+        json={"data_id": 1, "tag": "study_14"},
+    )
+    assert response.status_code == HTTPStatus.UNAUTHORIZED
+    assert response.headers["content-type"] == "application/problem+json"
+    error = response.json()
+    assert error["type"] == AuthenticationFailedError.uri
+    assert error["code"] == "103"
+
+
+async def test_dataset_untag(py_api: httpx.AsyncClient, expdb_test: AsyncConnection) -> None:
+    dataset_id = 1
+    tag = "temp_dataset_untag"
+    tagged = await py_api.post(
+        f"/datasets/tag?api_key={ApiKey.SOME_USER}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+    # Surface a broken precondition here rather than as a confusing untag failure.
+    assert tagged.status_code == HTTPStatus.OK
+
+    response = await py_api.post(
+        f"/datasets/untag?api_key={ApiKey.SOME_USER}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+    assert response.status_code == HTTPStatus.OK
+    assert response.json() == {"data_untag": {"id": str(dataset_id)}}
+    assert tag not in await get_tags_for(id_=dataset_id, connection=expdb_test)
+
+
+async def test_dataset_untag_rejects_other_user(py_api: httpx.AsyncClient) -> None:
+    dataset_id = 1
+    tag = "temp_dataset_untag_not_owned"
+    tagged = await py_api.post(
+        f"/datasets/tag?api_key={ApiKey.SOME_USER}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+    # Surface a broken precondition here rather than as a confusing untag failure.
+    assert tagged.status_code == HTTPStatus.OK
+
+    response = await py_api.post(
+        f"/datasets/untag?api_key={ApiKey.OWNER_USER}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+    assert response.status_code == HTTPStatus.FORBIDDEN
+    assert response.json()["code"] == "476"
+    assert "not created by you" in response.json()["detail"]
+
+    # Succeeds only if the rejected request above left the tag in place.
+    cleanup = await py_api.post(
+        f"/datasets/untag?api_key={ApiKey.SOME_USER}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+    assert cleanup.status_code == HTTPStatus.OK
+
+
+async def test_dataset_untag_fails_if_tag_does_not_exist(py_api: httpx.AsyncClient) -> None:
+    dataset_id = 1
+    tag = "definitely_not_a_dataset_tag"
+    response = await py_api.post(
+        f"/datasets/untag?api_key={ApiKey.ADMIN}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert response.json()["code"] == "475"
+    assert "does not have tag" in response.json()["detail"]
+
+
+@pytest.mark.parametrize(
+    "tag",
+    ["", "h@", " a", "a" * 65],
+    ids=["too short", "@", "space", "too long"],
+)
+async def test_dataset_untag_invalid_tag_is_rejected(
+    tag: str,
+    py_api: httpx.AsyncClient,
+) -> None:
+    response = await py_api.post(
+        f"/datasets/untag?api_key={ApiKey.ADMIN}",
+        json={"data_id": 1, "tag": tag},
+    )
+
+    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
+    assert response.json()["detail"][0]["loc"] == ["body", "tag"]
diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py
index 5ff6fe86..6cf39001 100644
--- a/tests/routers/openml/migration/datasets_migration_test.py
+++ b/tests/routers/openml/migration/datasets_migration_test.py
@@ -226,6 +226,65 @@ async def test_dataset_tag_response_is_identical(
     assert original == new
 
 
+@pytest.mark.parametrize(
+    "dataset_id",
+    [1, 2, 3, 101, 131],
+)
+@pytest.mark.parametrize(
+    "api_key",
+    [ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
+    ids=["Administrator", "regular user", "possible owner"],
+)
+@pytest.mark.parametrize(
+    "tag",
+    ["study_14", "study_15"],
+)
+async def test_dataset_untag_response_is_identical(
+    dataset_id: int,
+    tag: str,
+    api_key: str,
+    py_api: httpx.AsyncClient,
+    php_api: httpx.AsyncClient,
+) -> None:
+    original = await php_api.post(
+        "/data/untag",
+        data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
+    )
+    if original.status_code == HTTPStatus.OK:
+        # The PHP call really removed the tag; re-add it so the Python API
+        # has a tag to remove when it receives the same request.
+        await php_api.post(
+            "/data/tag",
+            data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
+        )
+
+    new = await py_api.post(
+        f"/datasets/untag?api_key={api_key}",
+        json={"data_id": dataset_id, "tag": tag},
+    )
+
+    if new.status_code == HTTPStatus.OK:
+        assert original.status_code == new.status_code, original.json()
+        assert original.json() == new.json()
+        return
+
+    # Both APIs must reject the request. Checking this before reading the
+    # error payload avoids an opaque KeyError when only the new API failed.
+    assert original.status_code == HTTPStatus.PRECONDITION_FAILED, original.json()
+    error = original.json()["error"]
+    code, message = error["code"], error["message"]
+    if message == "Tag is not owned by you":
+        assert new.status_code == HTTPStatus.FORBIDDEN
+        assert code == new.json()["code"]
+        assert "not created by you" in new.json()["detail"]
+        return
+
+    assert new.status_code == HTTPStatus.NOT_FOUND
+    assert code == new.json()["code"]
+    assert message == "Tag not found."
+    assert "does not have tag" in new.json()["detail"]
+
+
 @pytest.mark.parametrize(
     "data_id",
     list(range(1, 130)),