Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions services/chatbot/src/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import httpx
from fastmcp import FastMCP

from .tool_helpers import fix_array_responses_in_spec
from .tool_helpers import fix_array_responses_in_spec, OpenAPIRefResolver
from .config import Config

# Configure logging
Expand Down Expand Up @@ -78,7 +78,8 @@ def get_http_client():
# Load your OpenAPI spec
with open(Config.OPENAPI_SPEC, "r") as f:
openapi_spec = json.load(f)
openapi_spec = fix_array_responses_in_spec(openapi_spec)
OpenAPIRefResolver(openapi_spec).format_openapi_spec()
fix_array_responses_in_spec(openapi_spec)

# Create the MCP server
mcp = FastMCP.from_openapi(
Expand Down
69 changes: 67 additions & 2 deletions services/chatbot/src/mcpserver/tool_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os

from chatbot.extensions import db


Expand All @@ -25,5 +24,71 @@ def fix_array_responses_in_spec(spec):

if schema.get("type") == "array":
del media["schema"]

class OpenAPIRefResolver:
def __init__(self, spec):
self.spec = spec
self.components = spec.get("components", {}).get("schemas", {})

def resolve_ref(self, ref):
if not ref.startswith("#/components/schemas/"):
return None

schema_name = ref.split("/")[-1]
if schema_name not in self.components:
return None

return self.components[schema_name]

def inline_all_refs(self, schema, visited=None):
if visited is None:
visited = set()

if isinstance(schema, dict):
if "$ref" in schema:
ref = schema["$ref"]
if ref.startswith("#/components/schemas/"):
schema_name = ref.split("/")[-1]

if schema_name in visited:
return {"type": "object", "description": f"Circular reference to {schema_name}"}

visited.add(schema_name)
resolved = self.resolve_ref(ref)
if resolved:
inlined = self.inline_all_refs(resolved, visited.copy())
visited.discard(schema_name)
return inlined
else:
return schema
else:
return schema
else:
return {key: self.inline_all_refs(value, visited) for key, value in schema.items()}
elif isinstance(schema, list):
return [self.inline_all_refs(item, visited) for item in schema]
else:
return schema

def process_schema_recursively(self, schema):
return self.inline_all_refs(schema)

return spec
def format_openapi_spec(self):
for path_item in self.spec.get("paths", {}).values():
for method, operation in path_item.items():
if method in ["get", "post", "put", "patch", "delete", "options", "head", "trace"]:
if "requestBody" in operation:
content = operation["requestBody"].get("content", {})
for media_obj in content.values():
if "schema" in media_obj:
media_obj["schema"] = self.process_schema_recursively(media_obj["schema"])

for response in operation.get("responses", {}).values():
content = response.get("content", {})
for media_obj in content.values():
if "schema" in media_obj:
media_obj["schema"] = self.process_schema_recursively(media_obj["schema"])

if "components" in self.spec and "schemas" in self.spec["components"]:
for schema_name, schema_def in self.spec["components"]["schemas"].items():
self.spec["components"]["schemas"][schema_name] = self.process_schema_recursively(schema_def)
Loading