diff --git a/api/main.py b/api/main.py index 17b10ac1..19c674a0 100644 --- a/api/main.py +++ b/api/main.py @@ -17,7 +17,7 @@ import traceback from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union import pymongo from beanie import PydanticObjectId @@ -53,6 +53,8 @@ KernelVersion, Node, PublishEvent, + ResultValues, + StateValues, TelemetryEvent, parse_node_obj, ) @@ -1649,6 +1651,84 @@ async def put_node( return obj +class NodePatchRequest(BaseModel): + """Request model for partial node updates""" + + state: Optional[StateValues] = None + result: Optional[ResultValues] = None + artifacts: Optional[Dict[str, str]] = None + data: Optional[Dict[str, Any]] = None + debug: Optional[Dict[str, Any]] = None + jobfilter: Optional[List[str]] = None + platform_filter: Optional[List[str]] = None + timeout: Optional[datetime] = None + holdoff: Optional[datetime] = None + processed_by_kcidb_bridge: Optional[bool] = None + + +@app.patch("/node/{node_id}", response_model=Node, response_model_by_alias=False) +async def patch_node( + node_id: str, + patch: NodePatchRequest, + user: str = Depends(authorize_user), + noevent: Optional[bool] = Query(None), +): + """Partial update of an existing node""" + metrics.add("http_requests_total", 1) + node_from_id = await db.find_by_id(Node, node_id) + if not node_from_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Node not found with id: {node_id}", + ) + + update_data = patch.model_dump(exclude_unset=True) + if not update_data: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No fields to update", + ) + + # Handle state transition separately + new_state = update_data.pop("state", None) + + # Apply non-state fields to existing node + if update_data: + new_node_def = node_from_id.model_copy(update=update_data) + else: + new_node_def = node_from_id.model_copy() + + # Validate node subtype + specialized_node = parse_node_obj(new_node_def) + + # State transition checks + if new_state is not None: + is_valid, message = specialized_node.validate_node_state_transition( + new_state + ) + if not is_valid: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=message + ) + if new_state != new_node_def.state: + new_node_def.processed_by_kcidb_bridge = False + new_node_def.state = new_state + + # KCIDB flags reset logic + if "processed_by_kcidb_bridge" not in patch.model_dump(exclude_unset=True): + new_node_def.processed_by_kcidb_bridge = False + + # Update node in the DB + obj = await db.update(new_node_def) + data = _get_node_event_data("updated", obj) + attributes = {} + if data.get("owner", None): + attributes["owner"] = data["owner"] + if not noevent: + await pubsub.publish_cloudevent("node", data, attributes) + return obj + + class NodeUpdateRequest(BaseModel): """Request model for updating multiple nodes""" diff --git a/tests/e2e_tests/test_node_handler.py b/tests/e2e_tests/test_node_handler.py index 252dcac1..f331a343 100644 --- a/tests/e2e_tests/test_node_handler.py +++ b/tests/e2e_tests/test_node_handler.py @@ -91,3 +91,23 @@ async def update_node(test_async_client, node): ) assert response.status_code == 200 assert response.json().keys() == node_model_fields + + +async def patch_node(test_async_client, node_id, patch_data): + """ + Test Case : Test KernelCI API PATCH /node/{node_id} endpoint + Expected Result : + HTTP Response Code 200 OK + JSON with updated Node object + """ + response = await test_async_client.patch( + f"node/{node_id}", + headers={ + "Accept": "application/json", + "Authorization": f"Bearer {pytest.BEARER_TOKEN}", # pylint: disable=no-member + }, + data=json.dumps(patch_data), + ) + assert response.status_code == 200 + assert response.json().keys() == node_model_fields + return response diff --git a/tests/e2e_tests/test_pipeline.py b/tests/e2e_tests/test_pipeline.py index 7b0c12e4..f87099cc 100644 --- a/tests/e2e_tests/test_pipeline.py +++ b/tests/e2e_tests/test_pipeline.py @@ -10,7 +10,7 @@ from cloudevents.http import from_json from .listen_handler import create_listen_task -from .test_node_handler import create_node, get_node_by_id, update_node +from .test_node_handler import create_node, get_node_by_id, patch_node, update_node @pytest.mark.dependency(