From 81f2dfd437d2f1468fb2955ef41fb554cd323497 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 00:56:53 +0000 Subject: [PATCH 1/2] Initial plan From b284c5fbc3d3cb496e99072b7dcc99564bf2b81a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 25 Feb 2026 01:07:29 +0000 Subject: [PATCH 2/2] =?UTF-8?q?refactor:=20semantic=20function=20clusterin?= =?UTF-8?q?g=20=E2=80=94=20split=20connection.go=20and=20move=20misplaced?= =?UTF-8?q?=20server=20helpers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: lpcox <15877973+lpcox@users.noreply.github.com> --- internal/mcp/connection.go | 481 ------------------------------- internal/mcp/http_transport.go | 496 ++++++++++++++++++++++++++++++++ internal/server/auth.go | 9 + internal/server/http_helpers.go | 12 + internal/server/transport.go | 21 -- 5 files changed, 517 insertions(+), 502 deletions(-) create mode 100644 internal/mcp/http_transport.go diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index 86c8adf7..f0374fea 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -8,117 +8,19 @@ import ( "fmt" "io" "log" - "log/slog" "net/http" "os/exec" "strings" - "sync/atomic" "time" "github.com/github/gh-aw-mcpg/internal/dockerutil" "github.com/github/gh-aw-mcpg/internal/logger" "github.com/github/gh-aw-mcpg/internal/logger/sanitize" - "github.com/github/gh-aw-mcpg/internal/strutil" - "github.com/github/gh-aw-mcpg/internal/version" sdk "github.com/modelcontextprotocol/go-sdk/mcp" ) var logConn = logger.New("mcp:connection") -// isHTTPConnectionError checks if an error is a network connection error -// This helper reduces code duplication for checking common connection error patterns. -// Note: Uses string matching which is fragile but consistent with existing patterns in the codebase. -// TODO: Consider using errors.Is() or type assertions (*net.OpError) for more robust error classification. -func isHTTPConnectionError(err error) bool { - if err == nil { - return false - } - errMsg := err.Error() - return strings.Contains(errMsg, "connection refused") || - strings.Contains(errMsg, "no such host") || - strings.Contains(errMsg, "network is unreachable") -} - -// parseSSEResponse extracts JSON data from SSE-formatted response -// SSE format: "event: message\ndata: {json}\n\n" -func parseSSEResponse(body []byte) ([]byte, error) { - lines := strings.Split(string(body), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "data: ") { - jsonData := strings.TrimPrefix(line, "data: ") - return []byte(jsonData), nil - } - } - return nil, fmt.Errorf("no data field found in SSE response") -} - -// parseJSONRPCResponseWithSSE parses a JSON-RPC response that might be in SSE format. -// This helper consolidates duplicate response parsing logic that appears in multiple places. -// -// The function tries to parse the body as JSON first. If that fails, it attempts to extract -// JSON from SSE format (event: message\ndata: {...}). This handles backends that return -// responses in Server-Sent Events format. -// -// Parameters: -// - body: Raw response body bytes -// - statusCode: HTTP status code (used for enhanced error messages) -// - contextDesc: Description of the calling context (e.g., "initialize response", "JSON-RPC response") -// -// Returns: -// - *Response: Parsed JSON-RPC response on success -// - error: Detailed parsing error with body preview on failure -func parseJSONRPCResponseWithSSE(body []byte, statusCode int, contextDesc string) (*Response, error) { - var rpcResponse Response - httpErrorResponse := func() *Response { - return &Response{ - JSONRPC: "2.0", - Error: &ResponseError{ - Code: -32603, // Internal error - Message: fmt.Sprintf("HTTP %d: %s", statusCode, http.StatusText(statusCode)), - Data: json.RawMessage(body), - }, - } - } - - // Try parsing as standard JSON first - if err := json.Unmarshal(body, &rpcResponse); err != nil { - // Try parsing as SSE format - logConn.Printf("Initial JSON parse failed, attempting SSE format parsing") - sseData, sseErr := parseSSEResponse(body) - if sseErr != nil { - // If we have a non-OK HTTP status and can't parse the response, - // construct a JSON-RPC error response with HTTP error details - if statusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", statusCode) - return httpErrorResponse(), nil - } - // Include the response body to help debug what the server actually returned - bodyPreview := strutil.TruncateWithSuffix(string(body), 500, "... (truncated)") - return nil, fmt.Errorf("failed to parse %s (received non-JSON or malformed response): %w\nResponse body: %s", contextDesc, sseErr, bodyPreview) - } - - // Successfully extracted JSON from SSE, now parse it - if err := json.Unmarshal(sseData, &rpcResponse); err != nil { - // If we have a non-OK HTTP status and can't parse the SSE data, - // construct a JSON-RPC error response with HTTP error details - if statusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", statusCode) - return httpErrorResponse(), nil - } - return nil, fmt.Errorf("failed to parse JSON data extracted from SSE response: %w\nJSON data: %s", err, string(sseData)) - } - logConn.Printf("Successfully parsed SSE-formatted response") - } - - if statusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d, returning synthetic JSON-RPC error response", statusCode) - return httpErrorResponse(), nil - } - - return &rpcResponse, nil -} - // ContextKey for session ID type ContextKey string @@ -126,21 +28,6 @@ type ContextKey string // This is the same key used in the server package to avoid circular dependencies const SessionIDContextKey ContextKey = "awmg-session-id" -// requestIDCounter is used to generate unique request IDs for HTTP requests -var requestIDCounter uint64 - -// HTTPTransportType represents the type of HTTP transport being used -type HTTPTransportType string - -const ( - // HTTPTransportStreamable uses the streamable HTTP transport (2025-03-26 spec) - HTTPTransportStreamable HTTPTransportType = "streamable" - // HTTPTransportSSE uses the SSE transport (2024-11-05 spec) - HTTPTransportSSE HTTPTransportType = "sse" - // HTTPTransportPlainJSON uses plain JSON-RPC 2.0 over HTTP POST (non-standard) - HTTPTransportPlainJSON HTTPTransportType = "plain-json" -) - // Connection represents a connection to an MCP server using the official SDK type Connection struct { client *sdk.Client @@ -157,150 +44,6 @@ type Connection struct { httpTransportType HTTPTransportType // Type of HTTP transport in use } -// newMCPClient creates a new MCP SDK client with standard implementation details -// Pass nil for logger parameter to disable SDK logging (for tests) -func newMCPClient(log *logger.Logger) *sdk.Client { - var slogLogger *slog.Logger - if log != nil { - slogLogger = logger.NewSlogLoggerWithHandler(log) - } - return sdk.NewClient(&sdk.Implementation{ - Name: "awmg", - Version: version.Get(), - }, &sdk.ClientOptions{ - Logger: slogLogger, - }) -} - -// newHTTPConnection creates a new HTTP Connection struct with common fields -func newHTTPConnection(ctx context.Context, cancel context.CancelFunc, client *sdk.Client, session *sdk.ClientSession, url string, headers map[string]string, httpClient *http.Client, transportType HTTPTransportType, serverID string) *Connection { - // Extract session ID from SDK session if available - var sessionID string - if session != nil { - sessionID = session.ID() - } - return &Connection{ - client: client, - session: session, - ctx: ctx, - cancel: cancel, - serverID: serverID, - isHTTP: true, - httpURL: url, - headers: headers, - httpClient: httpClient, - httpTransportType: transportType, - httpSessionID: sessionID, - } -} - -// createJSONRPCRequest creates a JSON-RPC 2.0 request map -func createJSONRPCRequest(requestID uint64, method string, params interface{}) map[string]interface{} { - return map[string]interface{}{ - "jsonrpc": "2.0", - "id": requestID, - "method": method, - "params": params, - } -} - -// ensureToolCallArguments ensures that the arguments field exists in tools/call params -// The MCP protocol requires the arguments field to always be present, even if empty -func ensureToolCallArguments(params interface{}) interface{} { - // Convert params to map if it isn't already - paramsMap, ok := params.(map[string]interface{}) - if !ok { - // If params isn't a map, return as-is (this shouldn't happen for tools/call) - return params - } - - // Check if arguments field exists - if _, hasArgs := paramsMap["arguments"]; !hasArgs { - // Add empty arguments map if missing - paramsMap["arguments"] = make(map[string]interface{}) - } else if paramsMap["arguments"] == nil { - // Replace nil with empty map - paramsMap["arguments"] = make(map[string]interface{}) - } - - return paramsMap -} - -// setupHTTPRequest creates and configures an HTTP request with standard headers -func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, headers map[string]string) (*http.Request, error) { - httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %w", err) - } - - // Set standard headers - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("Accept", "application/json, text/event-stream") - - // Add configured headers (e.g., Authorization) - for key, value := range headers { - httpReq.Header.Set(key, value) - } - - return httpReq, nil -} - -// httpRequestResult contains the result of an HTTP request execution -type httpRequestResult struct { - StatusCode int - ResponseBody []byte - Header http.Header -} - -// executeHTTPRequest executes an HTTP JSON-RPC request and returns the response details. -// This helper consolidates the common pattern of: create request → marshal → setup HTTP → execute → read response. -// It handles connection errors consistently and provides method-specific error messages. -// The headerModifier function allows callers to modify headers before the request is sent. -func (c *Connection) executeHTTPRequest(ctx context.Context, method string, params interface{}, requestID uint64, headerModifier func(*http.Request)) (*httpRequestResult, error) { - // Create JSON-RPC request - request := createJSONRPCRequest(requestID, method, params) - - // Marshal request body - requestBody, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal %s request: %w", method, err) - } - - // Create HTTP request with standard headers - httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers) - if err != nil { - return nil, err - } - - // Allow caller to modify headers (e.g., add session ID) - if headerModifier != nil { - headerModifier(httpReq) - } - - // Execute HTTP request - httpResp, err := c.httpClient.Do(httpReq) - if err != nil { - // Check if it's a connection error (cannot connect at all) - if isHTTPConnectionError(err) { - return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) - } - return nil, fmt.Errorf("%s HTTP request failed: %w", method, err) - } - defer httpResp.Body.Close() - - // Read response body - responseBody, err := io.ReadAll(httpResp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read %s response: %w", method, err) - } - - return &httpRequestResult{ - StatusCode: httpResp.StatusCode, - ResponseBody: responseBody, - Header: httpResp.Header, - }, nil -} - // NewConnection creates a new MCP connection using the official SDK func NewConnection(ctx context.Context, serverID, command string, args []string, env map[string]string) (*Connection, error) { logger.LogInfo("backend", "Creating new MCP backend connection, command=%s, args=%v", command, sanitize.SanitizeArgs(args)) @@ -497,104 +240,6 @@ func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[st return nil, fmt.Errorf("failed to connect using any HTTP transport (tried streamable, SSE, and plain JSON-RPC): last error: %w", err) } -// transportConnector is a function that creates an SDK transport for a given URL and HTTP client -type transportConnector func(url string, httpClient *http.Client) sdk.Transport - -// trySDKTransport is a generic function to attempt connection with any SDK-based transport -// It handles the common logic of creating a client, connecting with timeout, and returning a connection -func trySDKTransport( - ctx context.Context, - cancel context.CancelFunc, - serverID string, - url string, - headers map[string]string, - httpClient *http.Client, - transportType HTTPTransportType, - transportName string, - createTransport transportConnector, -) (*Connection, error) { - // Create MCP client with logger - client := newMCPClient(logConn) - - // Create transport using the provided connector - transport := createTransport(url, httpClient) - - // Try to connect with a timeout - this will fail if the server doesn't support this transport - // Use a short timeout to fail fast and try other transports - connectCtx, connectCancel := context.WithTimeout(ctx, 5*time.Second) - defer connectCancel() - - session, err := client.Connect(connectCtx, transport, nil) - if err != nil { - return nil, fmt.Errorf("%s transport connect failed: %w", transportName, err) - } - - conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, transportType, serverID) - - logger.LogInfo("backend", "%s transport connected successfully", transportName) - logConn.Printf("Connected with %s transport", transportName) - return conn, nil -} - -// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec) -func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, serverID, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { - return trySDKTransport( - ctx, cancel, serverID, url, headers, httpClient, - HTTPTransportStreamable, - "streamable HTTP", - func(url string, httpClient *http.Client) sdk.Transport { - return &sdk.StreamableClientTransport{ - Endpoint: url, - HTTPClient: httpClient, - MaxRetries: 0, // Don't retry on failure - we'll try other transports - } - }, - ) -} - -// trySSETransport attempts to connect using the SSE transport (2024-11-05 spec) -func trySSETransport(ctx context.Context, cancel context.CancelFunc, serverID, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { - return trySDKTransport( - ctx, cancel, serverID, url, headers, httpClient, - HTTPTransportSSE, - "SSE", - func(url string, httpClient *http.Client) sdk.Transport { - return &sdk.SSEClientTransport{ - Endpoint: url, - HTTPClient: httpClient, - } - }, - ) -} - -// tryPlainJSONTransport attempts to connect using plain JSON-RPC 2.0 over HTTP POST (non-standard) -// This is used for compatibility with servers like safeinputs that don't implement standard MCP HTTP transports -func tryPlainJSONTransport(ctx context.Context, cancel context.CancelFunc, serverID, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { - conn := &Connection{ - ctx: ctx, - cancel: cancel, - serverID: serverID, - isHTTP: true, - httpURL: url, - headers: headers, - httpClient: httpClient, - httpTransportType: HTTPTransportPlainJSON, - } - - // Send initialize request to establish a session with the HTTP backend - // This is critical for backends that require session management - logConn.Printf("Sending initialize request via plain JSON-RPC to: %s", url) - sessionID, err := conn.initializeHTTPSession() - if err != nil { - return nil, fmt.Errorf("plain JSON-RPC initialize failed: %w", err) - } - - conn.httpSessionID = sessionID - logger.LogInfo("backend", "Plain JSON-RPC transport connected successfully with session=%s", sessionID) - logConn.Printf("Connected with plain JSON-RPC transport, session=%s", sessionID) - return conn, nil -} - // IsHTTP returns true if this is an HTTP connection func (c *Connection) IsHTTP() bool { return c.isHTTP @@ -689,132 +334,6 @@ func (c *Connection) callSDKMethod(method string, params interface{}) (*Response } } -// initializeHTTPSession sends an initialize request to the HTTP backend and captures the session ID -func (c *Connection) initializeHTTPSession() (string, error) { - // Generate unique request ID - requestID := atomic.AddUint64(&requestIDCounter, 1) - - // Create initialize request with MCP protocol parameters - initParams := map[string]interface{}{ - "protocolVersion": "2024-11-05", - "capabilities": map[string]interface{}{}, - "clientInfo": map[string]interface{}{ - "name": "awmg", - "version": "1.0.0", - }, - } - - logConn.Printf("Sending initialize request") - - // Generate a temporary session ID for the initialize request - // Some backends may require this header even during initialization - tempSessionID := fmt.Sprintf("awmg-init-%d", requestID) - - // Execute HTTP request with custom header modification - result, err := c.executeHTTPRequest(context.Background(), "initialize", initParams, requestID, func(httpReq *http.Request) { - httpReq.Header.Set("Mcp-Session-Id", tempSessionID) - logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID) - }) - if err != nil { - return "", err - } - - // Capture the Mcp-Session-Id from response headers - sessionID := result.Header.Get("Mcp-Session-Id") - if sessionID != "" { - logConn.Printf("Captured Mcp-Session-Id from response: %s", sessionID) - } else { - // If no session ID in response, use the temporary one - // This handles backends that don't return a session ID - sessionID = tempSessionID - logConn.Printf("No Mcp-Session-Id in response, using temporary session ID: %s", sessionID) - } - - logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", result.StatusCode, len(result.ResponseBody), sessionID) - - // Check for HTTP errors - if result.StatusCode != http.StatusOK { - return "", fmt.Errorf("initialize failed: status=%d, body=%s", result.StatusCode, string(result.ResponseBody)) - } - - // Parse JSON-RPC response to check for errors - // The response might be in SSE format (event: message\ndata: {...}) - rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "initialize response") - if err != nil { - return "", err - } - - if rpcResponse.Error != nil { - return "", fmt.Errorf("initialize error: code=%d, message=%s", rpcResponse.Error.Code, rpcResponse.Error.Message) - } - - return sessionID, nil -} - -// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server -// The ctx parameter is used to extract session ID for the Mcp-Session-Id header -func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) { - // Generate unique request ID using atomic counter - requestID := atomic.AddUint64(&requestIDCounter, 1) - - // For tools/call, ensure arguments field always exists (MCP protocol requirement) - if method == "tools/call" { - params = ensureToolCallArguments(params) - } - - logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID) - - // Execute HTTP request with custom header modification for session ID - result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) { - // Add Mcp-Session-Id header with priority: - // 1) Context session ID (if explicitly provided for this request) - // 2) Stored httpSessionID from initialization - var sessionID string - if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" { - sessionID = ctxSessionID - logConn.Printf("Using session ID from context: %s", sessionID) - } else if c.httpSessionID != "" { - sessionID = c.httpSessionID - logConn.Printf("Using stored session ID from initialization: %s", sessionID) - } - - if sessionID != "" { - httpReq.Header.Set("Mcp-Session-Id", sessionID) - } else { - logConn.Printf("No session ID available (backend may not require session management)") - } - }) - if err != nil { - return nil, err - } - - logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody)) - - // Parse JSON-RPC response - // The response might be in SSE format (event: message\ndata: {...}) - rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "JSON-RPC response") - if err != nil { - return nil, err - } - - // Check for HTTP errors after parsing - // If we have a non-OK status but successfully parsed a JSON-RPC response, - // pass it through (it may already contain an error field) - if result.StatusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode) - // If the response doesn't already have an error, construct one - if rpcResponse.Error == nil { - rpcResponse.Error = &ResponseError{ - Code: -32603, // Internal error - Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)), - Data: result.ResponseBody, - } - } - } - - return rpcResponse, nil -} - // marshalToResponse marshals an SDK result into a Response object. // This helper reduces code duplication across all MCP method wrappers. func marshalToResponse(result interface{}) (*Response, error) { diff --git a/internal/mcp/http_transport.go b/internal/mcp/http_transport.go new file mode 100644 index 00000000..5be5cb6b --- /dev/null +++ b/internal/mcp/http_transport.go @@ -0,0 +1,496 @@ +package mcp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "sync/atomic" + "time" + + "github.com/github/gh-aw-mcpg/internal/logger" + "github.com/github/gh-aw-mcpg/internal/strutil" + "github.com/github/gh-aw-mcpg/internal/version" + sdk "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// HTTPTransportType represents the type of HTTP transport being used +type HTTPTransportType string + +const ( + // HTTPTransportStreamable uses the streamable HTTP transport (2025-03-26 spec) + HTTPTransportStreamable HTTPTransportType = "streamable" + // HTTPTransportSSE uses the SSE transport (2024-11-05 spec) + HTTPTransportSSE HTTPTransportType = "sse" + // HTTPTransportPlainJSON uses plain JSON-RPC 2.0 over HTTP POST (non-standard) + HTTPTransportPlainJSON HTTPTransportType = "plain-json" +) + +// requestIDCounter is used to generate unique request IDs for HTTP requests +var requestIDCounter uint64 + +// httpRequestResult contains the result of an HTTP request execution +type httpRequestResult struct { + StatusCode int + ResponseBody []byte + Header http.Header +} + +// transportConnector is a function that creates an SDK transport for a given URL and HTTP client +type transportConnector func(url string, httpClient *http.Client) sdk.Transport + +// isHTTPConnectionError checks if an error is a network connection error +// This helper reduces code duplication for checking common connection error patterns. +// Note: Uses string matching which is fragile but consistent with existing patterns in the codebase. +// TODO: Consider using errors.Is() or type assertions (*net.OpError) for more robust error classification. +func isHTTPConnectionError(err error) bool { + if err == nil { + return false + } + errMsg := err.Error() + return strings.Contains(errMsg, "connection refused") || + strings.Contains(errMsg, "no such host") || + strings.Contains(errMsg, "network is unreachable") +} + +// parseSSEResponse extracts JSON data from SSE-formatted response +// SSE format: "event: message\ndata: {json}\n\n" +func parseSSEResponse(body []byte) ([]byte, error) { + lines := strings.Split(string(body), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "data: ") { + jsonData := strings.TrimPrefix(line, "data: ") + return []byte(jsonData), nil + } + } + return nil, fmt.Errorf("no data field found in SSE response") +} + +// parseJSONRPCResponseWithSSE parses a JSON-RPC response that might be in SSE format. +// This helper consolidates duplicate response parsing logic that appears in multiple places. +// +// The function tries to parse the body as JSON first. If that fails, it attempts to extract +// JSON from SSE format (event: message\ndata: {...}). This handles backends that return +// responses in Server-Sent Events format. +// +// Parameters: +// - body: Raw response body bytes +// - statusCode: HTTP status code (used for enhanced error messages) +// - contextDesc: Description of the calling context (e.g., "initialize response", "JSON-RPC response") +// +// Returns: +// - *Response: Parsed JSON-RPC response on success +// - error: Detailed parsing error with body preview on failure +func parseJSONRPCResponseWithSSE(body []byte, statusCode int, contextDesc string) (*Response, error) { + var rpcResponse Response + httpErrorResponse := func() *Response { + return &Response{ + JSONRPC: "2.0", + Error: &ResponseError{ + Code: -32603, // Internal error + Message: fmt.Sprintf("HTTP %d: %s", statusCode, http.StatusText(statusCode)), + Data: json.RawMessage(body), + }, + } + } + + // Try parsing as standard JSON first + if err := json.Unmarshal(body, &rpcResponse); err != nil { + // Try parsing as SSE format + logConn.Printf("Initial JSON parse failed, attempting SSE format parsing") + sseData, sseErr := parseSSEResponse(body) + if sseErr != nil { + // If we have a non-OK HTTP status and can't parse the response, + // construct a JSON-RPC error response with HTTP error details + if statusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", statusCode) + return httpErrorResponse(), nil + } + // Include the response body to help debug what the server actually returned + bodyPreview := strutil.TruncateWithSuffix(string(body), 500, "... (truncated)") + return nil, fmt.Errorf("failed to parse %s (received non-JSON or malformed response): %w\nResponse body: %s", contextDesc, sseErr, bodyPreview) + } + + // Successfully extracted JSON from SSE, now parse it + if err := json.Unmarshal(sseData, &rpcResponse); err != nil { + // If we have a non-OK HTTP status and can't parse the SSE data, + // construct a JSON-RPC error response with HTTP error details + if statusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", statusCode) + return httpErrorResponse(), nil + } + return nil, fmt.Errorf("failed to parse JSON data extracted from SSE response: %w\nJSON data: %s", err, string(sseData)) + } + logConn.Printf("Successfully parsed SSE-formatted response") + } + + if statusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, returning synthetic JSON-RPC error response", statusCode) + return httpErrorResponse(), nil + } + + return &rpcResponse, nil +} + +// newMCPClient creates a new MCP SDK client with standard implementation details +// Pass nil for logger parameter to disable SDK logging (for tests) +func newMCPClient(log *logger.Logger) *sdk.Client { + var slogLogger *slog.Logger + if log != nil { + slogLogger = logger.NewSlogLoggerWithHandler(log) + } + return sdk.NewClient(&sdk.Implementation{ + Name: "awmg", + Version: version.Get(), + }, &sdk.ClientOptions{ + Logger: slogLogger, + }) +} + +// newHTTPConnection creates a new HTTP Connection struct with common fields +func newHTTPConnection(ctx context.Context, cancel context.CancelFunc, client *sdk.Client, session *sdk.ClientSession, url string, headers map[string]string, httpClient *http.Client, transportType HTTPTransportType, serverID string) *Connection { + // Extract session ID from SDK session if available + var sessionID string + if session != nil { + sessionID = session.ID() + } + return &Connection{ + client: client, + session: session, + ctx: ctx, + cancel: cancel, + serverID: serverID, + isHTTP: true, + httpURL: url, + headers: headers, + httpClient: httpClient, + httpTransportType: transportType, + httpSessionID: sessionID, + } +} + +// createJSONRPCRequest creates a JSON-RPC 2.0 request map +func createJSONRPCRequest(requestID uint64, method string, params interface{}) map[string]interface{} { + return map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "method": method, + "params": params, + } +} + +// ensureToolCallArguments ensures that the arguments field exists in tools/call params +// The MCP protocol requires the arguments field to always be present, even if empty +func ensureToolCallArguments(params interface{}) interface{} { + // Convert params to map if it isn't already + paramsMap, ok := params.(map[string]interface{}) + if !ok { + // If params isn't a map, return as-is (this shouldn't happen for tools/call) + return params + } + + // Check if arguments field exists + if _, hasArgs := paramsMap["arguments"]; !hasArgs { + // Add empty arguments map if missing + paramsMap["arguments"] = make(map[string]interface{}) + } else if paramsMap["arguments"] == nil { + // Replace nil with empty map + paramsMap["arguments"] = make(map[string]interface{}) + } + + return paramsMap +} + +// setupHTTPRequest creates and configures an HTTP request with standard headers +func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, headers map[string]string) (*http.Request, error) { + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set standard headers + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json, text/event-stream") + + // Add configured headers (e.g., Authorization) + for key, value := range headers { + httpReq.Header.Set(key, value) + } + + return httpReq, nil +} + +// executeHTTPRequest executes an HTTP JSON-RPC request and returns the response details. +// This helper consolidates the common pattern of: create request → marshal → setup HTTP → execute → read response. +// It handles connection errors consistently and provides method-specific error messages. +// The headerModifier function allows callers to modify headers before the request is sent. +func (c *Connection) executeHTTPRequest(ctx context.Context, method string, params interface{}, requestID uint64, headerModifier func(*http.Request)) (*httpRequestResult, error) { + // Create JSON-RPC request + request := createJSONRPCRequest(requestID, method, params) + + // Marshal request body + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal %s request: %w", method, err) + } + + // Create HTTP request with standard headers + httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers) + if err != nil { + return nil, err + } + + // Allow caller to modify headers (e.g., add session ID) + if headerModifier != nil { + headerModifier(httpReq) + } + + // Execute HTTP request + httpResp, err := c.httpClient.Do(httpReq) + if err != nil { + // Check if it's a connection error (cannot connect at all) + if isHTTPConnectionError(err) { + return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) + } + return nil, fmt.Errorf("%s HTTP request failed: %w", method, err) + } + defer httpResp.Body.Close() + + // Read response body + responseBody, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read %s response: %w", method, err) + } + + return &httpRequestResult{ + StatusCode: httpResp.StatusCode, + ResponseBody: responseBody, + Header: httpResp.Header, + }, nil +} + +// trySDKTransport is a generic function to attempt connection with any SDK-based transport +// It handles the common logic of creating a client, connecting with timeout, and returning a connection +func trySDKTransport( + ctx context.Context, + cancel context.CancelFunc, + serverID string, + url string, + headers map[string]string, + httpClient *http.Client, + transportType HTTPTransportType, + transportName string, + createTransport transportConnector, +) (*Connection, error) { + // Create MCP client with logger + client := newMCPClient(logConn) + + // Create transport using the provided connector + transport := createTransport(url, httpClient) + + // Try to connect with a timeout - this will fail if the server doesn't support this transport + // Use a short timeout to fail fast and try other transports + connectCtx, connectCancel := context.WithTimeout(ctx, 5*time.Second) + defer connectCancel() + + session, err := client.Connect(connectCtx, transport, nil) + if err != nil { + return nil, fmt.Errorf("%s transport connect failed: %w", transportName, err) + } + + conn := newHTTPConnection(ctx, cancel, client, session, url, headers, httpClient, transportType, serverID) + + logger.LogInfo("backend", "%s transport connected successfully", transportName) + logConn.Printf("Connected with %s transport", transportName) + return conn, nil +} + +// tryStreamableHTTPTransport attempts to connect using the streamable HTTP transport (2025-03-26 spec) +func tryStreamableHTTPTransport(ctx context.Context, cancel context.CancelFunc, serverID, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + return trySDKTransport( + ctx, cancel, serverID, url, headers, httpClient, + HTTPTransportStreamable, + "streamable HTTP", + func(url string, httpClient *http.Client) sdk.Transport { + return &sdk.StreamableClientTransport{ + Endpoint: url, + HTTPClient: httpClient, + MaxRetries: 0, // Don't retry on failure - we'll try other transports + } + }, + ) +} + +// trySSETransport attempts to connect using the SSE transport (2024-11-05 spec) +func trySSETransport(ctx context.Context, cancel context.CancelFunc, serverID, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + return trySDKTransport( + ctx, cancel, serverID, url, headers, httpClient, + HTTPTransportSSE, + "SSE", + func(url string, httpClient *http.Client) sdk.Transport { + return &sdk.SSEClientTransport{ + Endpoint: url, + HTTPClient: httpClient, + } + }, + ) +} + +// tryPlainJSONTransport attempts to connect using plain JSON-RPC 2.0 over HTTP POST (non-standard) +// This is used for compatibility with servers like safeinputs that don't implement standard MCP HTTP transports +func tryPlainJSONTransport(ctx context.Context, cancel context.CancelFunc, serverID, url string, headers map[string]string, httpClient *http.Client) (*Connection, error) { + conn := &Connection{ + ctx: ctx, + cancel: cancel, + serverID: serverID, + isHTTP: true, + httpURL: url, + headers: headers, + httpClient: httpClient, + httpTransportType: HTTPTransportPlainJSON, + } + + // Send initialize request to establish a session with the HTTP backend + // This is critical for backends that require session management + logConn.Printf("Sending initialize request via plain JSON-RPC to: %s", url) + sessionID, err := conn.initializeHTTPSession() + if err != nil { + return nil, fmt.Errorf("plain JSON-RPC initialize failed: %w", err) + } + + conn.httpSessionID = sessionID + logger.LogInfo("backend", "Plain JSON-RPC transport connected successfully with session=%s", sessionID) + logConn.Printf("Connected with plain JSON-RPC transport, session=%s", sessionID) + return conn, nil +} + +// initializeHTTPSession sends an initialize request to the HTTP backend and captures the session ID +func (c *Connection) initializeHTTPSession() (string, error) { + // Generate unique request ID + requestID := atomic.AddUint64(&requestIDCounter, 1) + + // Create initialize request with MCP protocol parameters + initParams := map[string]interface{}{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]interface{}{}, + "clientInfo": map[string]interface{}{ + "name": "awmg", + "version": "1.0.0", + }, + } + + logConn.Printf("Sending initialize request") + + // Generate a temporary session ID for the initialize request + // Some backends may require this header even during initialization + tempSessionID := fmt.Sprintf("awmg-init-%d", requestID) + + // Execute HTTP request with custom header modification + result, err := c.executeHTTPRequest(context.Background(), "initialize", initParams, requestID, func(httpReq *http.Request) { + httpReq.Header.Set("Mcp-Session-Id", tempSessionID) + logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID) + }) + if err != nil { + return "", err + } + + // Capture the Mcp-Session-Id from response headers + sessionID := result.Header.Get("Mcp-Session-Id") + if sessionID != "" { + logConn.Printf("Captured Mcp-Session-Id from response: %s", sessionID) + } else { + // If no session ID in response, use the temporary one + // This handles backends that don't return a session ID + sessionID = tempSessionID + logConn.Printf("No Mcp-Session-Id in response, using temporary session ID: %s", sessionID) + } + + logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", result.StatusCode, len(result.ResponseBody), sessionID) + + // Check for HTTP errors + if result.StatusCode != http.StatusOK { + return "", fmt.Errorf("initialize failed: status=%d, body=%s", result.StatusCode, string(result.ResponseBody)) + } + + // Parse JSON-RPC response to check for errors + // The response might be in SSE format (event: message\ndata: {...}) + rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "initialize response") + if err != nil { + return "", err + } + + if rpcResponse.Error != nil { + return "", fmt.Errorf("initialize error: code=%d, message=%s", rpcResponse.Error.Code, rpcResponse.Error.Message) + } + + return sessionID, nil +} + +// sendHTTPRequest sends a JSON-RPC request to an HTTP MCP server +// The ctx parameter is used to extract session ID for the Mcp-Session-Id header +func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params interface{}) (*Response, error) { + // Generate unique request ID using atomic counter + requestID := atomic.AddUint64(&requestIDCounter, 1) + + // For tools/call, ensure arguments field always exists (MCP protocol requirement) + if method == "tools/call" { + params = ensureToolCallArguments(params) + } + + logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID) + + // Execute HTTP request with custom header modification for session ID + result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) { + // Add Mcp-Session-Id header with priority: + // 1) Context session ID (if explicitly provided for this request) + // 2) Stored httpSessionID from initialization + var sessionID string + if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" { + sessionID = ctxSessionID + logConn.Printf("Using session ID from context: %s", sessionID) + } else if c.httpSessionID != "" { + sessionID = c.httpSessionID + logConn.Printf("Using stored session ID from initialization: %s", sessionID) + } + + if sessionID != "" { + httpReq.Header.Set("Mcp-Session-Id", sessionID) + } else { + logConn.Printf("No session ID available (backend may not require session management)") + } + }) + if err != nil { + return nil, err + } + + logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody)) + + // Parse JSON-RPC response + // The response might be in SSE format (event: message\ndata: {...}) + rpcResponse, err := parseJSONRPCResponseWithSSE(result.ResponseBody, result.StatusCode, "JSON-RPC response") + if err != nil { + return nil, err + } + + // Check for HTTP errors after parsing + // If we have a non-OK status but successfully parsed a JSON-RPC response, + // pass it through (it may already contain an error field) + if result.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode) + // If the response doesn't already have an error, construct one + if rpcResponse.Error == nil { + rpcResponse.Error = &ResponseError{ + Code: -32603, // Internal error + Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)), + Data: result.ResponseBody, + } + } + } + + return rpcResponse, nil +} diff --git a/internal/server/auth.go b/internal/server/auth.go index eb72a4f2..22fc0cb9 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -46,3 +46,12 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc { next(w, r) } } + +// applyAuthIfConfigured applies authentication middleware if an API key is provided +// Returns the handler unchanged if apiKey is empty +func applyAuthIfConfigured(apiKey string, handler http.HandlerFunc) http.HandlerFunc { + if apiKey != "" { + return authMiddleware(apiKey, handler) + } + return handler +} diff --git a/internal/server/http_helpers.go b/internal/server/http_helpers.go index 129a8d54..9c9194af 100644 --- a/internal/server/http_helpers.go +++ b/internal/server/http_helpers.go @@ -15,6 +15,18 @@ import ( var logHelpers = logger.New("server:helpers") +// withResponseLogging wraps an http.Handler to log response bodies +func withResponseLogging(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lw := newResponseWriter(w) + handler.ServeHTTP(lw, r) + if len(lw.Body()) > 0 { + sanitizedBody := sanitize.SanitizeString(string(lw.Body())) + log.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode(), sanitizedBody) + } + }) +} + // extractAndValidateSession extracts the session ID from the Authorization header // and logs connection details. Returns empty string if validation fails. func extractAndValidateSession(r *http.Request) string { diff --git a/internal/server/transport.go b/internal/server/transport.go index 63c0640f..0a5875d1 100644 --- a/internal/server/transport.go +++ b/internal/server/transport.go @@ -44,27 +44,6 @@ func (t *HTTPTransport) Close() error { return nil } -// withResponseLogging wraps an http.Handler to log response bodies -func withResponseLogging(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lw := newResponseWriter(w) - handler.ServeHTTP(lw, r) - if len(lw.Body()) > 0 { - sanitizedBody := sanitize.SanitizeString(string(lw.Body())) - log.Printf("[%s] %s %s - Status: %d, Response: %s", r.RemoteAddr, r.Method, r.URL.Path, lw.StatusCode(), sanitizedBody) - } - }) -} - -// applyAuthIfConfigured applies authentication middleware if an API key is provided -// Returns the handler unchanged if apiKey is empty -func applyAuthIfConfigured(apiKey string, handler http.HandlerFunc) http.HandlerFunc { - if apiKey != "" { - return authMiddleware(apiKey, handler) - } - return handler -} - // CreateHTTPServerForMCP creates an HTTP server that handles MCP over streamable HTTP transport // If apiKey is provided, all requests except /health require authentication (spec 7.1) func CreateHTTPServerForMCP(addr string, unifiedServer *UnifiedServer, apiKey string) *http.Server {