diff --git a/services/chatbot/src/mcpserver/server.py b/services/chatbot/src/mcpserver/server.py index 8cb8998a..e0087d11 100644 --- a/services/chatbot/src/mcpserver/server.py +++ b/services/chatbot/src/mcpserver/server.py @@ -6,7 +6,7 @@ import httpx from fastmcp import FastMCP -from .tool_helpers import fix_array_responses_in_spec +from .tool_helpers import fix_array_responses_in_spec, OpenAPIRefResolver from .config import Config # Configure logging @@ -78,7 +78,8 @@ def get_http_client(): # Load your OpenAPI spec with open(Config.OPENAPI_SPEC, "r") as f: openapi_spec = json.load(f) -openapi_spec = fix_array_responses_in_spec(openapi_spec) +OpenAPIRefResolver(openapi_spec).format_openapi_spec() +fix_array_responses_in_spec(openapi_spec) # Create the MCP server mcp = FastMCP.from_openapi( diff --git a/services/chatbot/src/mcpserver/tool_helpers.py b/services/chatbot/src/mcpserver/tool_helpers.py index 38c4f306..6ef1610f 100644 --- a/services/chatbot/src/mcpserver/tool_helpers.py +++ b/services/chatbot/src/mcpserver/tool_helpers.py @@ -1,5 +1,4 @@ import os - from chatbot.extensions import db @@ -25,5 +24,71 @@ def fix_array_responses_in_spec(spec): if schema.get("type") == "array": del media["schema"] + +class OpenAPIRefResolver: + def __init__(self, spec): + self.spec = spec + self.components = spec.get("components", {}).get("schemas", {}) + + def resolve_ref(self, ref): + if not ref.startswith("#/components/schemas/"): + return None + + schema_name = ref.split("/")[-1] + if schema_name not in self.components: + return None + + return self.components[schema_name] + + def inline_all_refs(self, schema, visited=None): + if visited is None: + visited = set() + + if isinstance(schema, dict): + if "$ref" in schema: + ref = schema["$ref"] + if ref.startswith("#/components/schemas/"): + schema_name = ref.split("/")[-1] + + if schema_name in visited: + return {"type": "object", "description": f"Circular reference to {schema_name}"} + + visited.add(schema_name) + resolved = self.resolve_ref(ref) + if resolved: + inlined = self.inline_all_refs(resolved, visited.copy()) + visited.discard(schema_name) + return inlined + else: + return schema + else: + return schema + else: + return {key: self.inline_all_refs(value, visited) for key, value in schema.items()} + elif isinstance(schema, list): + return [self.inline_all_refs(item, visited) for item in schema] + else: + return schema + + def process_schema_recursively(self, schema): + return self.inline_all_refs(schema) - return spec \ No newline at end of file + def format_openapi_spec(self): + for path_item in self.spec.get("paths", {}).values(): + for method, operation in path_item.items(): + if method in ["get", "post", "put", "patch", "delete", "options", "head", "trace"]: + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}) + for media_obj in content.values(): + if "schema" in media_obj: + media_obj["schema"] = self.process_schema_recursively(media_obj["schema"]) + + for response in operation.get("responses", {}).values(): + content = response.get("content", {}) + for media_obj in content.values(): + if "schema" in media_obj: + media_obj["schema"] = self.process_schema_recursively(media_obj["schema"]) + + if "components" in self.spec and "schemas" in self.spec["components"]: + for schema_name, schema_def in self.spec["components"]["schemas"].items(): + self.spec["components"]["schemas"][schema_name] = self.process_schema_recursively(schema_def) \ No newline at end of file