Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/database/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,35 @@ async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection)
)


async def get_tags(id_: int, connection: AsyncConnection) -> list[Row]:
row = await connection.execute(
text(
"""
SELECT *
FROM dataset_tag
WHERE id = :dataset_id
""",
),
parameters={"dataset_id": id_},
)
return list(row.all())


async def untag(id_: int, tag_: str, *, connection: AsyncConnection) -> None:
await connection.execute(
text(
"""
DELETE FROM dataset_tag
WHERE `id` = :dataset_id AND `tag` = :tag
""",
),
parameters={
"dataset_id": id_,
"tag": tag_,
},
)


async def get_description(
id_: int,
connection: AsyncConnection,
Expand Down
35 changes: 35 additions & 0 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
InternalError,
NoResultsError,
TagAlreadyExistsError,
TagNotFoundError,
TagNotOwnedError,
)
from core.formatting import (
_csv_as_list,
Expand Down Expand Up @@ -66,6 +68,39 @@ async def tag_dataset(
}


@router.post(
path="/untag",
)
async def untag_dataset(
data_id: Annotated[int, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> dict[str, dict[str, Any]]:
assert expdb_db is not None # noqa: S101
if not await database.datasets.get(data_id, expdb_db):
msg = f"No dataset with id {data_id} found."
raise DatasetNotFoundError(msg)

dataset_tags = await database.datasets.get_tags(data_id, expdb_db)
matched_tag_row = next((t for t in dataset_tags if t.tag.casefold() == tag.casefold()), None)
if matched_tag_row is None:
msg = f"Dataset {data_id} does not have tag {tag!r}."
raise TagNotFoundError(msg)

if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in await user.get_groups():
msg = (
f"You may not remove tag {tag!r} of dataset {data_id} "
"because it was not created by you."
)
raise TagNotOwnedError(msg)

await database.datasets.untag(data_id, matched_tag_row.tag, connection=expdb_db)
return {
"data_untag": {"id": str(data_id)},
}


class DatasetStatusFilter(StrEnum):
ACTIVE = DatasetStatus.ACTIVE
DEACTIVATED = DatasetStatus.DEACTIVATED
Expand Down
88 changes: 88 additions & 0 deletions tests/routers/openml/dataset_tag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,91 @@ async def test_dataset_tag_invalid_tag_is_rejected(

assert new.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert new.json()["detail"][0]["loc"] == ["body", "tag"]


@pytest.mark.parametrize(
"key",
[None, ApiKey.INVALID],
ids=["no authentication", "invalid key"],
)
async def test_dataset_untag_rejects_unauthorized(key: ApiKey, py_api: httpx.AsyncClient) -> None:
apikey = "" if key is None else f"?api_key={key}"
response = await py_api.post(
f"/datasets/untag{apikey}",
json={"data_id": 1, "tag": "study_14"},
)
assert response.status_code == HTTPStatus.UNAUTHORIZED
assert response.headers["content-type"] == "application/problem+json"
error = response.json()
assert error["type"] == AuthenticationFailedError.uri
assert error["code"] == "103"


async def test_dataset_untag(py_api: httpx.AsyncClient, expdb_test: AsyncConnection) -> None:
dataset_id = 1
tag = "temp_dataset_untag"
await py_api.post(
f"/datasets/tag?api_key={ApiKey.SOME_USER}",
json={"data_id": dataset_id, "tag": tag},
)

response = await py_api.post(
f"/datasets/untag?api_key={ApiKey.SOME_USER}",
json={"data_id": dataset_id, "tag": tag},
)
assert response.status_code == HTTPStatus.OK
assert response.json() == {"data_untag": {"id": str(dataset_id)}}
assert tag not in await get_tags_for(id_=dataset_id, connection=expdb_test)


async def test_dataset_untag_rejects_other_user(py_api: httpx.AsyncClient) -> None:
dataset_id = 1
tag = "temp_dataset_untag_not_owned"
await py_api.post(
f"/datasets/tag?api_key={ApiKey.SOME_USER}",
json={"data_id": dataset_id, "tag": tag},
)

response = await py_api.post(
f"/datasets/untag?api_key={ApiKey.OWNER_USER}",
json={"data_id": dataset_id, "tag": tag},
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert response.json()["code"] == "476"
assert "not created by you" in response.json()["detail"]

cleanup = await py_api.post(
f"/datasets/untag?api_key={ApiKey.SOME_USER}",
json={"data_id": dataset_id, "tag": tag},
)
assert cleanup.status_code == HTTPStatus.OK


async def test_dataset_untag_fails_if_tag_does_not_exist(py_api: httpx.AsyncClient) -> None:
dataset_id = 1
tag = "definitely_not_a_dataset_tag"
response = await py_api.post(
f"/datasets/untag?api_key={ApiKey.ADMIN}",
json={"data_id": dataset_id, "tag": tag},
)
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["code"] == "475"
assert "does not have tag" in response.json()["detail"]


@pytest.mark.parametrize(
"tag",
["", "h@", " a", "a" * 65],
ids=["too short", "@", "space", "too long"],
)
async def test_dataset_untag_invalid_tag_is_rejected(
tag: str,
py_api: httpx.AsyncClient,
) -> None:
response = await py_api.post(
f"/datasets/untag?api_key={ApiKey.ADMIN}",
json={"data_id": 1, "tag": tag},
)

assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
assert response.json()["detail"][0]["loc"] == ["body", "tag"]
55 changes: 55 additions & 0 deletions tests/routers/openml/migration/datasets_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,61 @@ async def test_dataset_tag_response_is_identical(
assert original == new


@pytest.mark.parametrize(
"dataset_id",
[1, 2, 3, 101, 131],
)
@pytest.mark.parametrize(
"api_key",
[ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
ids=["Administrator", "regular user", "possible owner"],
)
@pytest.mark.parametrize(
"tag",
["study_14", "study_15"],
)
async def test_dataset_untag_response_is_identical(
dataset_id: int,
tag: str,
api_key: str,
py_api: httpx.AsyncClient,
php_api: httpx.AsyncClient,
) -> None:
original = await php_api.post(
"/data/untag",
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)
if original.status_code == HTTPStatus.OK:
await php_api.post(
"/data/tag",
data={"api_key": api_key, "tag": tag, "data_id": dataset_id},
)

new = await py_api.post(
f"/datasets/untag?api_key={api_key}",
json={"data_id": dataset_id, "tag": tag},
)

if new.status_code == HTTPStatus.OK:
assert original.status_code == new.status_code, original.json()
assert original.json() == new.json()
return

code, message = original.json()["error"].values()
if message == "Tag is not owned by you":
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
assert new.status_code == HTTPStatus.FORBIDDEN
assert code == new.json()["code"]
assert "not created by you" in new.json()["detail"]
return

assert original.status_code == HTTPStatus.PRECONDITION_FAILED
assert new.status_code == HTTPStatus.NOT_FOUND
assert code == new.json()["code"]
assert message == "Tag not found."
assert "does not have tag" in new.json()["detail"]


@pytest.mark.parametrize(
"data_id",
list(range(1, 130)),
Expand Down