diff --git a/src/itential_mcp/tools/operations_manager.py b/src/itential_mcp/tools/operations_manager.py index 5744d27..566f200 100644 --- a/src/itential_mcp/tools/operations_manager.py +++ b/src/itential_mcp/tools/operations_manager.py @@ -2,7 +2,11 @@ # GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Annotated +from __future__ import annotations + +import json + +from typing import Annotated, Any from pydantic import Field @@ -18,6 +22,45 @@ __tags__ = ("operations_manager",) +def _coerce_value(value: Any, schema: dict) -> Any: + """Coerce a single value to the type declared in its JSON Schema definition. + + LLMs often stringify nested structures (arrays, objects) when calling tools. + This converts them back to the correct Python type using the schema as ground truth. + If the value cannot be parsed or the schema has no type declaration, it is returned + unchanged so the platform produces the real error rather than a silent wrong cast. + """ + expected_type = schema.get("type") + if not isinstance(value, str) or expected_type not in ("array", "object"): + return value + + try: + parsed = json.loads(value) + except (json.JSONDecodeError, ValueError): + return value + + if expected_type == "array" and isinstance(parsed, list): + return parsed + if expected_type == "object" and isinstance(parsed, dict): + return parsed + + return value + + +def _coerce_data_to_schema(data: dict, input_schema: dict) -> dict: + """Recursively coerce data values to match the types declared in the workflow input schema. + + Only top-level properties are coerced — nested structures beyond the first level + are left to the platform's own validation. + """ + properties: dict = input_schema.get("properties") or {} + result = {} + for key, value in data.items(): + prop_schema = properties.get(key, {}) + result[key] = _coerce_value(value, prop_schema) + return result + + async def _account_id_to_username(ctx: Context, account_id: str) -> str: """Retrieve the username for an account ID. @@ -181,6 +224,20 @@ async def start_workflow( if isinstance(data, str): data = jsonutils.loads(data) + # Coerce stringified values (e.g. arrays passed as strings by LLMs) using + # the workflow's declared input schema so the platform receives the right types. + if data: + workflows = await client.operations_manager.get_workflows() + input_schema = next( + ( + w.get("schema") or {} + for w in workflows + if w.get("routeName") == route_name + ), + {}, + ) + data = _coerce_data_to_schema(data, input_schema) + res = await client.operations_manager.start_workflow(route_name, data) metrics_data = res.get("metrics") or {} diff --git a/tests/test_tools_operations_manager.py b/tests/test_tools_operations_manager.py index 24c10a1..406a4e8 100644 --- a/tests/test_tools_operations_manager.py +++ b/tests/test_tools_operations_manager.py @@ -268,6 +268,138 @@ async def test_start_workflow_minimal_metrics(self, mock_context): assert result.metrics.user is None +class TestCoerceValue: + """Unit tests for _coerce_value""" + + def test_string_to_array(self): + result = operations_manager._coerce_value('["a", "b"]', {"type": "array"}) + assert result == ["a", "b"] + + def test_string_to_object(self): + result = operations_manager._coerce_value('{"key": "val"}', {"type": "object"}) + assert result == {"key": "val"} + + def test_non_json_string_unchanged(self): + result = operations_manager._coerce_value("hello", {"type": "array"}) + assert result == "hello" + + def test_single_quoted_string_unchanged(self): + # Single quotes are not valid JSON — value should pass through unmodified + result = operations_manager._coerce_value("['a', 'b']", {"type": "array"}) + assert result == "['a', 'b']" + + def test_correct_type_unchanged(self): + result = operations_manager._coerce_value(["a", "b"], {"type": "array"}) + assert result == ["a", "b"] + + def test_no_schema_type_unchanged(self): + result = operations_manager._coerce_value('["a"]', {}) + assert result == '["a"]' + + def test_type_mismatch_after_parse_unchanged(self): + # Schema says array but parsed value is an object — leave it for platform to reject + result = operations_manager._coerce_value('{"k": "v"}', {"type": "array"}) + assert result == '{"k": "v"}' + + +class TestCoerceDataToSchema: + """Unit tests for _coerce_data_to_schema""" + + def test_coerces_stringified_array(self): + schema = {"properties": {"config": {"type": "array"}}} + data = {"device": "router1", "config": '["interface eth0"]'} + result = operations_manager._coerce_data_to_schema(data, schema) + assert result["config"] == ["interface eth0"] + assert result["device"] == "router1" + + def test_unknown_keys_passed_through(self): + schema = {"properties": {}} + data = {"extra": '["x"]'} + result = operations_manager._coerce_data_to_schema(data, schema) + assert result["extra"] == '["x"]' + + def test_empty_schema_no_coercion(self): + data = {"config": '["a", "b"]'} + result = operations_manager._coerce_data_to_schema(data, {}) + assert result["config"] == '["a", "b"]' + + +class TestStartWorkflowCoercion: + """Test that start_workflow coerces data values using the workflow schema""" + + @pytest.fixture + def mock_context(self): + context = AsyncMock(spec=Context) + context.info = AsyncMock() + mock_client = MagicMock() + mock_operations_manager = AsyncMock() + mock_client.operations_manager = mock_operations_manager + context.request_context = MagicMock() + context.request_context.lifespan_context = MagicMock() + context.request_context.lifespan_context.get.return_value = mock_client + return context + + @pytest.fixture + def mock_job_response(self): + return { + "_id": "job-999", + "name": "FRR Device Config", + "description": None, + "tasks": {}, + "status": "running", + "metrics": {}, + } + + @pytest.mark.asyncio + async def test_stringified_array_coerced_before_post( + self, mock_context, mock_job_response + ): + """A config value sent as a JSON string is coerced to a list before the platform call.""" + workflow_schema = { + "type": "object", + "properties": { + "device": {"type": "string"}, + "config": {"type": "array"}, + }, + } + mock_context.request_context.lifespan_context.get.return_value.operations_manager.get_workflows.return_value = [ + {"routeName": "frr_device_cfg", "schema": workflow_schema} + ] + mock_context.request_context.lifespan_context.get.return_value.operations_manager.start_workflow.return_value = mock_job_response + + await operations_manager.start_workflow( + mock_context, + "frr_device_cfg", + { + "device": "chicago-p", + "config": '["interface eth0", "description Management"]', + }, + ) + + call_args = mock_context.request_context.lifespan_context.get.return_value.operations_manager.start_workflow.call_args + sent_data = call_args[0][1] + assert sent_data["config"] == ["interface eth0", "description Management"] + assert sent_data["device"] == "chicago-p" + + @pytest.mark.asyncio + async def test_no_matching_workflow_data_passed_through( + self, mock_context, mock_job_response + ): + """When route_name has no schema match, data is forwarded unchanged.""" + mock_context.request_context.lifespan_context.get.return_value.operations_manager.get_workflows.return_value = [] + mock_context.request_context.lifespan_context.get.return_value.operations_manager.start_workflow.return_value = mock_job_response + + await operations_manager.start_workflow( + mock_context, + "unknown_route", + {"config": '["a", "b"]'}, + ) + + call_args = mock_context.request_context.lifespan_context.get.return_value.operations_manager.start_workflow.call_args + sent_data = call_args[0][1] + assert sent_data["config"] == '["a", "b"]' + + class TestGetJobs: """Test the get_jobs tool function"""