Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/core/tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from collections.abc import Awaitable, Callable
from typing import Any

from sqlalchemy import Row
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import TagAlreadyExistsError, TagNotFoundError, TagNotOwnedError
from database.users import User, UserGroup


async def tag_entity(
    entity_id: int,
    tag: str,
    user: User,
    expdb: AsyncConnection,
    *,
    get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]],
    tag_fn: Callable[..., Awaitable[None]],
    response_key: str,
) -> dict[str, dict[str, Any]]:
    """Attach `tag` to an entity and report the entity's tags afterwards.

    The entity-specific storage calls are injected so the same flow can be
    reused for every taggable entity type.

    Args:
        entity_id: Identifier of the entity to tag.
        tag: Tag text to attach.
        user: User performing the action; `user.user_id` is passed to `tag_fn`.
        expdb: Connection forwarded to the injected callables.
        get_tags_fn: Returns the tags currently attached to the entity.
        tag_fn: Stores the tag; called as
            `tag_fn(entity_id, tag, user_id=..., expdb=...)`.
        response_key: Top-level key of the returned mapping.

    Returns:
        `{response_key: {"id": str(entity_id), "tag": [...]}}`, where the tag
        list is re-read after the insert.

    Raises:
        TagAlreadyExistsError: If the entity already has this tag, compared
            case-insensitively via `str.casefold`.
    """
    current_tags = await get_tags_fn(entity_id, expdb)
    wanted = tag.casefold()
    if any(present.casefold() == wanted for present in current_tags):
        msg = f"Entity {entity_id} already tagged with {tag!r}."
        raise TagAlreadyExistsError(msg)

    await tag_fn(entity_id, tag, user_id=user.user_id, expdb=expdb)

    # Read the tags again so the response reflects the stored state.
    updated_tags = await get_tags_fn(entity_id, expdb)
    payload: dict[str, Any] = {"id": str(entity_id), "tag": updated_tags}
    return {response_key: payload}


async def untag_entity(
    entity_id: int,
    tag: str,
    user: User,
    expdb: AsyncConnection,
    *,
    get_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[Row | None]],
    delete_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[None]],
    get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]],
    response_key: str,
) -> dict[str, dict[str, Any]]:
    """Remove `tag` from an entity and report the tags that remain.

    Args:
        entity_id: Identifier of the entity to untag.
        tag: Tag text to remove.
        user: User performing the action.
        expdb: Connection forwarded to the injected callables.
        get_tag_fn: Looks up the tag row; the row must expose `uploader`.
        delete_tag_fn: Deletes the tag from the entity.
        get_tags_fn: Returns the tags currently attached to the entity.
        response_key: Top-level key of the returned mapping.

    Returns:
        `{response_key: {"id": str(entity_id), "tag": [...]}}`, where the tag
        list is read after the delete.

    Raises:
        TagNotFoundError: If `get_tag_fn` finds no such tag on the entity.
        TagNotOwnedError: If the tag's uploader is not `user` and `user` is
            not in the `UserGroup.ADMIN` group.
    """
    tag_row = await get_tag_fn(entity_id, tag, expdb)
    if tag_row is None:
        msg = f"Tag {tag!r} not found on entity {entity_id}."
        raise TagNotFoundError(msg)

    user_groups = await user.get_groups()
    is_owner = tag_row.uploader == user.user_id
    # Only the uploader of the tag or an administrator may remove it.
    if not (is_owner or UserGroup.ADMIN in user_groups):
        msg = f"Tag {tag!r} on entity {entity_id} is not owned by you."
        raise TagNotOwnedError(msg)

    await delete_tag_fn(entity_id, tag, expdb)

    remaining_tags = await get_tags_fn(entity_id, expdb)
    payload: dict[str, Any] = {"id": str(entity_id), "tag": remaining_tags}
    return {response_key: payload}
37 changes: 25 additions & 12 deletions src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.tagging import insert_tag, remove_tag, select_tag, select_tags

_TABLE = "implementation_tag"
_ID_COLUMN = "id"


async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
Expand All @@ -23,18 +28,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:


async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
    """Return the tags attached to the flow `flow_id`."""
    # Delegate to the shared helper instead of repeating the SELECT inline;
    # it reads the same table (`implementation_tag`) by the same `id` column.
    return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=flow_id, expdb=expdb)


async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
Expand All @@ -54,6 +48,25 @@ async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
)


async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
    """Attach `tag_` to flow `id_`, passing `user_id` on as the tagging user."""
    # fmt: off
    await insert_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, user_id=user_id, expdb=expdb)  # noqa: E501
    # fmt: on


async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
    """Look up the row linking `tag_` to flow `id_`; `None` when absent."""
    tag_row = await select_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        expdb=expdb,
    )
    return tag_row


async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
    """Delete `tag_` from flow `id_`."""
    await remove_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        expdb=expdb,
    )


async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
"""Get flow by name and external version."""
row = await expdb.execute(
Expand Down
44 changes: 44 additions & 0 deletions src/database/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.tagging import insert_tag, remove_tag, select_tag, select_tags

_TABLE = "run_tag"
_ID_COLUMN = "id"


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
    """Fetch the `run` row whose id equals `id_`, or `None` if there is none."""
    query = text(
        """
        SELECT *
        FROM run
        WHERE `id` = :run_id
        """,
    )
    result = await expdb.execute(query, parameters={"run_id": id_})
    return result.one_or_none()


async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
    """Return the tags attached to run `id_`."""
    run_tags = await select_tags(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        expdb=expdb,
    )
    return run_tags


async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
    """Attach `tag_` to run `id_`, passing `user_id` on as the tagging user."""
    # fmt: off
    await insert_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, user_id=user_id, expdb=expdb)  # noqa: E501
    # fmt: on


async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
    """Look up the row linking `tag_` to run `id_`; `None` when absent."""
    tag_row = await select_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        expdb=expdb,
    )
    return tag_row


async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
    """Delete `tag_` from run `id_`."""
    await remove_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        expdb=expdb,
    )
82 changes: 82 additions & 0 deletions src/database/tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection


async def insert_tag(
*,
table: str,
id_column: str,
Comment on lines +7 to +8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Constrain table and id_column before interpolating them into shared SQL.

These helpers are now the reusable tagging boundary, but they accept arbitrary identifier strings and splice them straight into every query. The current callers pass constants, yet one future non-constant call turns this module into an injection sink. Please validate against a closed set of supported identifier combinations here, instead of relying on every caller to stay disciplined.

Also applies to: 16-19, 35-39, 56-59, 74-78

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/database/tagging.py` around lines 7 - 8, the helpers in this module
accept the identifier parameters `table` and `id_column` and interpolate them
directly into shared SQL. Validate these inputs against a closed allowlist
before any string interpolation: add an allowlist mapping that reflects the
tables the callers actually use (e.g., ALLOWED_TABLES = {"implementation_tag":
{"id"}, "run_tag": {"id"}, "task_tag": {"id"}}), check that `table` is a key and
`id_column` is one of its allowed columns, and raise ValueError on a mismatch.
Perform this check at the start of every public helper that accepts
table/id_column so that no query is constructed for invalid values. When
identifiers must be embedded into SQL, quote them with the identifier-quoting
mechanism of the dialect in use (this module goes through SQLAlchemy and uses
backtick quoting, so psycopg2.sql.Identifier does not apply) rather than naive
string concatenation.

id_: int,
tag_: str,
user_id: int,
expdb: AsyncConnection,
) -> None:
await expdb.execute(
text(
f"""
INSERT INTO {table}(`{id_column}`, `tag`, `uploader`)
VALUES (:id, :tag, :user_id)
""",
),
parameters={"id": id_, "tag": tag_, "user_id": user_id},
)


async def select_tag(
    *,
    table: str,
    id_column: str,
    id_: int,
    tag_: str,
    expdb: AsyncConnection,
) -> Row | None:
    """Fetch the row that links `tag_` to the entity `id_`, if there is one.

    Args:
        table: Name of the tag table to query.
        id_column: Name of the column in `table` holding the entity id.
        id_: Identifier of the entity.
        tag_: Tag text to look for.
        expdb: Connection the query is executed on.

    Returns:
        A row exposing `id`, `tag` and `uploader`, or `None` when no row
        matches.

    Raises:
        ValueError: If `table` or `id_column` is not a plain ASCII identifier.
    """
    # Identifiers cannot be bound as parameters, so refuse anything that is
    # not a bare name before it is interpolated into the statement.
    for label, identifier in (("table", table), ("id_column", id_column)):
        if not (identifier.isascii() and identifier.isidentifier()):
            msg = f"Unsafe SQL identifier for {label}: {identifier!r}"
            raise ValueError(msg)

    result = await expdb.execute(
        text(
            f"""
            SELECT `{id_column}` as id, `tag`, `uploader`
            FROM {table}
            WHERE `{id_column}` = :id AND `tag` = :tag
            """,
        ),
        parameters={"id": id_, "tag": tag_},
    )
    return result.one_or_none()


async def remove_tag(
    *,
    table: str,
    id_column: str,
    id_: int,
    tag_: str,
    expdb: AsyncConnection,
    uploader: int | None = None,
) -> None:
    """Delete the row that links `tag_` to the entity `id_`.

    Args:
        table: Name of the tag table to delete from.
        id_column: Name of the column in `table` holding the entity id.
        id_: Identifier of the entity.
        tag_: Tag text to remove.
        expdb: Connection the statement is executed on.
        uploader: When given, only a row whose `uploader` equals this value
            is deleted. This lets a caller that has just checked ownership
            bind the delete to that owner in the same statement. `None`
            (the default) deletes regardless of uploader.

    Raises:
        ValueError: If `table` or `id_column` is not a plain ASCII identifier.
    """
    # Identifiers cannot be bound as parameters, so refuse anything that is
    # not a bare name before it is interpolated into the statement.
    for label, identifier in (("table", table), ("id_column", id_column)):
        if not (identifier.isascii() and identifier.isidentifier()):
            msg = f"Unsafe SQL identifier for {label}: {identifier!r}"
            raise ValueError(msg)

    conditions = f"`{id_column}` = :id AND `tag` = :tag"
    parameters: dict[str, int | str] = {"id": id_, "tag": tag_}
    if uploader is not None:
        conditions += " AND `uploader` = :uploader"
        parameters["uploader"] = uploader

    await expdb.execute(
        text(
            f"""
            DELETE FROM {table}
            WHERE {conditions}
            """,
        ),
        parameters=parameters,
    )
Comment on lines +46 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bind the DELETE to the checked owner.

The new /untag flow does the ownership check before it calls delete_tag, but remove_tag deletes by (id, tag) only. If that row is removed and recreated by someone else between the check and this DELETE, the original requester can still delete the new owner's tag. Make the mutation conditional on the expected uploader (or row PK) so authorization and deletion stay atomic.

🔒 Suggested direction
 async def remove_tag(
     *,
     table: str,
     id_column: str,
     id_: int,
     tag_: str,
+    uploader: int | None = None,
     expdb: AsyncConnection,
 ) -> None:
+    uploader_clause = " AND `uploader` = :uploader" if uploader is not None else ""
     await expdb.execute(
         text(
             f"""
             DELETE FROM {table}
-            WHERE `{id_column}` = :id AND `tag` = :tag
+            WHERE `{id_column}` = :id AND `tag` = :tag{uploader_clause}
             """,
         ),
-        parameters={"id": id_, "tag": tag_},
+        parameters={
+            "id": id_,
+            "tag": tag_,
+            **({"uploader": uploader} if uploader is not None else {}),
+        },
     )

Then thread the expected uploader through core.tagging.untag_entity for non-admin deletes.

🧰 Tools
🪛 Ruff (0.15.6)

[error] 56-59: Possible SQL injection vector through string-based query construction

(S608)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/database/tagging.py` around lines 46 - 62, remove_tag currently deletes
rows by (id, tag) only which allows TOCTOU deletion by another uploader; change
remove_tag to accept an expected_uploader (or expected_row_pk) parameter and
include it in the DELETE WHERE clause (e.g., AND `uploader` = :expected_uploader
or AND primary_key = :expected_pk) so the DELETE is conditional/atomic; update
callers (notably core.tagging.untag_entity and any pathway used by non-admins)
to pass the checked uploader value when calling remove_tag so authorization is
enforced in the same SQL statement.



async def select_tags(
    *,
    table: str,
    id_column: str,
    id_: int,
    expdb: AsyncConnection,
) -> list[str]:
    """Return every tag stored in `table` for the entity `id_`.

    Args:
        table: Name of the tag table to query.
        id_column: Name of the column in `table` holding the entity id.
        id_: Identifier of the entity.
        expdb: Connection the query is executed on.

    Returns:
        The `tag` value of each matching row; empty when there are none.

    Raises:
        ValueError: If `table` or `id_column` is not a plain ASCII identifier.
    """
    # Identifiers cannot be bound as parameters, so refuse anything that is
    # not a bare name before it is interpolated into the statement.
    for label, identifier in (("table", table), ("id_column", id_column)):
        if not (identifier.isascii() and identifier.isidentifier()):
            msg = f"Unsafe SQL identifier for {label}: {identifier!r}"
            raise ValueError(msg)

    result = await expdb.execute(
        text(
            f"""
            SELECT `tag`
            FROM {table}
            WHERE `{id_column}` = :id
            """,
        ),
        parameters={"id": id_},
    )
    return [row.tag for row in result.all()]
35 changes: 24 additions & 11 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.tagging import insert_tag, remove_tag, select_tag, select_tags

_TABLE = "task_tag"
_ID_COLUMN = "id"


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
Expand Down Expand Up @@ -103,15 +108,23 @@ async def get_task_type_inout_with_template(


async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
    """Return the tags attached to task `id_`."""
    # Delegate to the shared helper instead of repeating the SELECT inline;
    # it reads the same table (`task_tag`) by the same `id` column.
    return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb)


async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
    """Attach `tag_` to task `id_`, passing `user_id` on as the tagging user."""
    await insert_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        user_id=user_id,
        expdb=expdb,
    )


async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
    """Look up the row linking `tag_` to task `id_`; `None` when absent."""
    tag_row = await select_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        expdb=expdb,
    )
    return tag_row


async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
    """Delete `tag_` from task `id_`."""
    await remove_tag(
        table=_TABLE,
        id_column=_ID_COLUMN,
        id_=id_,
        tag_=tag_,
        expdb=expdb,
    )
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from routers.openml.evaluations import router as evaluationmeasures_router
from routers.openml.flows import router as flows_router
from routers.openml.qualities import router as qualities_router
from routers.openml.runs import router as runs_router
from routers.openml.setups import router as setup_router
from routers.openml.study import router as study_router
from routers.openml.tasks import router as task_router
Expand Down Expand Up @@ -68,6 +69,7 @@ def create_api() -> FastAPI:
app.include_router(estimationprocedure_router)
app.include_router(task_router)
app.include_router(flows_router)
app.include_router(runs_router)
app.include_router(study_router)
app.include_router(setup_router)
return app
Expand Down
46 changes: 43 additions & 3 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,57 @@
from typing import Annotated, Literal
from typing import Annotated, Any, Literal

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Body, Depends
from sqlalchemy.ext.asyncio import AsyncConnection

import database.flows
from core.conversions import _str_to_num
from core.errors import FlowNotFoundError
from routers.dependencies import expdb_connection
from core.tagging import tag_entity, untag_entity
from database.users import User
from routers.dependencies import expdb_connection, fetch_user_or_raise
from routers.types import SystemString64
from schemas.flows import Flow, Parameter, Subflow

router = APIRouter(prefix="/flows", tags=["flows"])


@router.post(path="/tag")
async def tag_flow(
    flow_id: Annotated[int, Body()],
    tag: Annotated[str, SystemString64],
    user: Annotated[User, Depends(fetch_user_or_raise)],
    expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
    """Tag a flow and return its tags under the `flow_tag` key."""
    # The shared tagging routine does the work; only the flow-specific
    # storage functions and the response key are supplied here.
    flow_store = database.flows
    response = await tag_entity(
        flow_id,
        tag,
        user,
        expdb,
        get_tags_fn=flow_store.get_tags,
        tag_fn=flow_store.tag,
        response_key="flow_tag",
    )
    return response


@router.post(path="/untag")
async def untag_flow(
    flow_id: Annotated[int, Body()],
    tag: Annotated[str, SystemString64],
    user: Annotated[User, Depends(fetch_user_or_raise)],
    expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
    """Remove a tag from a flow and return its remaining tags under `flow_tag`."""
    # The shared untagging routine does the work; only the flow-specific
    # storage functions and the response key are supplied here.
    flow_store = database.flows
    response = await untag_entity(
        flow_id,
        tag,
        user,
        expdb,
        get_tag_fn=flow_store.get_tag,
        delete_tag_fn=flow_store.delete_tag,
        get_tags_fn=flow_store.get_tags,
        response_key="flow_tag",
    )
    return response


@router.get("/exists/{name}/{external_version}")
async def flow_exists(
name: str,
Expand Down
Loading