From df5105c6b71a2cec8050f94a797de51838b6631e Mon Sep 17 00:00:00 2001 From: root Date: Fri, 15 May 2026 08:27:14 +0000 Subject: [PATCH 1/6] feat: add authz dependency injection --- src/blueapi/service/authentication.py | 9 +++++++++ src/blueapi/service/main.py | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index b107f7b2b..33201c9f2 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -272,3 +272,12 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +class OPAClient: # placeholder until https://jira.diamond.ac.uk/browse/ACQP-550 is done + def do_some_checks(self, task_request) -> bool: + return True + + +def get_opa_client() -> OPAClient: # placeholder + return OPAClient() diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index a53c46885..75325236e 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -37,6 +37,7 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.authentication import OPAClient, get_opa_client from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -278,6 +279,16 @@ def get_device_by_name( ) +def submission_check( + opa: Annotated[OPAClient, Depends(get_opa_client)], + task_request: TaskRequest, +): + allowed = opa.do_some_checks(task_request) + + if not allowed: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + @secure_router_v1.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @secure_router.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @start_as_current_span( @@ -291,6 +302,7 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], + authz_check: Annotated[None, Depends(submission_check)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: """Submit a task to the worker.""" From e7fe75f47868370925b521b1209e0df97949e2b3 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 18 May 2026 15:01:34 +0000 Subject: [PATCH 2/6] feat: add auth check dependency injections to task endpoints --- src/blueapi/service/main.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 75325236e..2075b7a74 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -167,6 +167,16 @@ def inner(request: Request, access_token: str = Depends(oauth_scheme)): TRACER = get_tracer("interface") +def submit_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], + task_request: TaskRequest, +): + allowed = opa.do_some_checks(task_request) + + if not allowed: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -279,16 +289,6 @@ def get_device_by_name( ) -def submission_check( - opa: Annotated[OPAClient, Depends(get_opa_client)], - task_request: TaskRequest, -): - allowed = opa.do_some_checks(task_request) - - if not allowed: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - - @secure_router_v1.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @secure_router.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @start_as_current_span( @@ -302,7 +302,7 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], - authz_check: Annotated[None, Depends(submission_check)], + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: """Submit a task to the worker.""" @@ -348,6 +348,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -366,6 +367,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @start_as_current_span(TRACER) def get_tasks( runner: Annotated[WorkerDispatcher, Depends(_runner)], + _: Annotated[None, Depends(submit_permission)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -402,6 +404,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -432,6 +435,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -509,6 +513,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ From 05c58e895c13201eb77395be49c2c1fe705f1394 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 08:13:42 +0000 Subject: [PATCH 3/6] feat: create new access task permission fns and add as dependencies --- .vscode/launch.json | 5 +++- src/blueapi/service/main.py | 57 +++++++++++++++++++++++++++++++++---- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 2308cfec6..fb5b50b33 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -34,7 +34,10 @@ "module": "blueapi", "args": "--config ${input:config_path} serve", "env": { - "OTLP_EXPORT_ENABLED": "false" + "OTLP_EXPORT_ENABLED": "false", + "EPICS_CA_NAME_SERVERS": "127.0.0.1:9064", + "EPICS_PVA_NAME_SERVERS": "127.0.0.1:9075", + "EPICS_CA_ADDR_LIST": "127.0.0.1:9064" }, }, { diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 2075b7a74..996a04c73 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -177,6 +177,41 @@ def submit_permission( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) +def access_task_permission( + request: Request, + task_id: str, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + try: + task = runner.run(interface.get_task_by_id, task_id) + except KeyError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None + + if ( + access_token + and task + and access_token.get("fedid") != task.task.metadata.get("user") + ): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + + +# start_task_permission is used when there is WorkerTask +def start_task_permission( + request: Request, + task: WorkerTask, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + if not task.task_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="No task id provided", + ) + access_task_permission(request, task.task_id, runner) + + async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -348,7 +383,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -366,8 +401,8 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) def get_tasks( + request: Request, runner: Annotated[WorkerDispatcher, Depends(_runner)], - _: Annotated[None, Depends(submit_permission)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -387,6 +422,15 @@ def get_tasks( tasks = runner.run(interface.get_tasks_by_status, desired_status) else: tasks = runner.run(interface.get_tasks) + + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + user = access_token.get("fedid") if access_token else None + + if user: + tasks = [t for t in tasks if t.task.metadata.get("user") == user] + return TasksListResponse(tasks=tasks) @@ -404,7 +448,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(start_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -435,7 +479,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -513,7 +557,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ @@ -545,6 +589,9 @@ def set_state( elif new_state == WorkerState.RUNNING: runner.run(interface.resume_worker) elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + # active = runner.run(interface.get_active_task) + # if active.task.metadata.get("user"): + try: runner.run( interface.cancel_active_task, From 30ada257c9d0c2ed7f0feaee3312deea0d992784 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 08:30:53 +0000 Subject: [PATCH 4/6] refactor: update rest api version --- src/blueapi/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 83d6d7021..e932ff84b 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -303,7 +303,7 @@ class ApplicationConfig(BlueapiBaseModel): """ #: API version to publish in OpenAPI schema - REST_API_VERSION: ClassVar[str] = "1.3.0" + REST_API_VERSION: ClassVar[str] = "1.3.1" LICENSE_INFO: ClassVar[dict[str, str]] = { "name": "Apache 2.0", From b5f0e062312aa944348e6cc59a02a269f4e80122 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 10:13:46 +0000 Subject: [PATCH 5/6] comment out dependency addition in set_state --- src/blueapi/config.py | 2 +- src/blueapi/service/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e932ff84b..83d6d7021 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -303,7 +303,7 @@ class ApplicationConfig(BlueapiBaseModel): """ #: API version to publish in OpenAPI schema - REST_API_VERSION: ClassVar[str] = "1.3.1" + REST_API_VERSION: ClassVar[str] = "1.3.0" LICENSE_INFO: ClassVar[dict[str, str]] = { "name": "Apache 2.0", diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 996a04c73..2c94a1158 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -557,7 +557,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, - _: Annotated[None, Depends(access_task_permission)], + # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ From 27a9865f2fe38ead633f568954ae559f420b1a80 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 14:16:24 +0000 Subject: [PATCH 6/6] refactor: add admin check and check to set state function --- pyproject.toml | 2 +- src/blueapi/service/authentication.py | 3 ++ src/blueapi/service/main.py | 35 ++++++++++++----------- tests/unit_tests/service/test_rest_api.py | 2 +- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 659779994..baceb70eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ filterwarnings = ["error", "ignore::DeprecationWarning"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" asyncio_mode = "auto" -timeout = 3 +timeout = 100 [tool.coverage.run] patch = ["subprocess"] diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 33201c9f2..61f439d5f 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -278,6 +278,9 @@ class OPAClient: # placeholder until https://jira.diamond.ac.uk/browse/ACQP-550 def do_some_checks(self, task_request) -> bool: return True + def admin(self): + return False + def get_opa_client() -> OPAClient: # placeholder return OPAClient() diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 2c94a1158..857a6b82b 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -178,6 +178,7 @@ def submit_permission( def access_task_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], request: Request, task_id: str, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -185,21 +186,19 @@ def access_task_permission( access_token: dict[str, Any] | None = getattr( request.state, "decoded_access_token", None ) - try: - task = runner.run(interface.get_task_by_id, task_id) - except KeyError: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None + task = runner.run(interface.get_task_by_id, task_id) - if ( + if not opa.admin() and ( access_token and task and access_token.get("fedid") != task.task.metadata.get("user") ): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # start_task_permission is used when there is WorkerTask def start_task_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], request: Request, task: WorkerTask, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -209,7 +208,7 @@ def start_task_permission( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="No task id provided", ) - access_task_permission(request, task.task_id, runner) + access_task_permission(opa, request, task.task_id, runner) async def on_key_error_404(_: Request, __: Exception): @@ -346,10 +345,7 @@ def submit_task( access_token: dict[str, Any] | None = getattr( request.state, "decoded_access_token", None ) - if access_token: - user: str = access_token.get("fedid", "Unknown") - else: - user = "Unknown" + user = access_token.get("fedid") if access_token else None task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) response.headers["Location"] = f"{request.url}/{task_id}" @@ -428,8 +424,7 @@ def get_tasks( ) user = access_token.get("fedid") if access_token else None - if user: - tasks = [t for t in tasks if t.task.metadata.get("user") == user] + tasks = [t for t in tasks if t.task.metadata.get("user") == user] return TasksListResponse(tasks=tasks) @@ -555,8 +550,10 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt ) @start_as_current_span(TRACER, "state_change_request.new_state") def set_state( + request: Request, state_change_request: StateChangeRequest, response: Response, + opa: Annotated[OPAClient, Depends(get_opa_client)], # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: @@ -584,14 +581,20 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): + active = runner.run(interface.get_active_task) + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + user = access_token.get("fedid") if access_token else None + + if not opa.admin() and active and active.task.metadata.get("user") != user: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) elif new_state == WorkerState.RUNNING: runner.run(interface.resume_worker) elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: - # active = runner.run(interface.get_active_task) - # if active.task.metadata.get("user"): - try: runner.run( interface.cancel_active_task, diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index c1d3b6a95..bf0a6a997 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -251,7 +251,7 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": None}) assert response.json() == {"task_id": task_id}