diff --git a/cmd/mcp/flags.go b/cmd/mcp/flags.go index ffe0ecd..7b7eaf9 100644 --- a/cmd/mcp/flags.go +++ b/cmd/mcp/flags.go @@ -73,7 +73,7 @@ var ServeFlags = []FlagDef{ ViperKey: "mcp.allow_api_key_query_param", Default: "false", IsBool: true, - Usage: "Accept the Seerr API key via the api_key query parameter in addition to the X-Api-Key header (HTTP transport only)", + Usage: "Accept the MCP auth token via the api_key query parameter in addition to headers (HTTP transport only)", }, { Name: "log-file", diff --git a/cmd/mcp/serve.go b/cmd/mcp/serve.go index 685f08a..aaf9451 100644 --- a/cmd/mcp/serve.go +++ b/cmd/mcp/serve.go @@ -159,11 +159,9 @@ func runServe(_ *cobra.Command, args []string) error { httpHandler := server.NewStreamableHTTPServer(s) handler := http.Handler(httpHandler) - // Per-request Seerr API key injection (header or optional query param). - handler = SeerrAPIKeyMiddleware(allowAPIKeyQueryParam, handler) handler = httpLoggingMiddleware(handler) if authToken != "" { - handler = bearerAuthMiddleware(authToken, handler) + handler = MCPAuthMiddleware(authToken, allowAPIKeyQueryParam, handler) } // The health endpoint must be reachable without auth, so register it in // a top-level mux that sits above the bearer-auth middleware. @@ -243,50 +241,31 @@ func corsMiddleware(next http.Handler) http.Handler { }) } -func bearerAuthMiddleware(token string, next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - const prefix = "Bearer " - if !strings.HasPrefix(authHeader, prefix) { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - provided := strings.TrimPrefix(authHeader, prefix) - if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - next.ServeHTTP(w, r) - }) -} - -// SeerrAPIKeyMiddleware extracts the Seerr API key from the incoming request -// and injects it into the request context for use by MCP tool handlers. +// MCPAuthMiddleware authenticates incoming MCP HTTP requests against token. // -// The key is read from the X-Api-Key request header first. When -// allowQueryParam is true the middleware also accepts the key via the -// api_key query parameter; the header takes precedence when both are present. +// Accepted credential locations (in precedence order): +// 1. Authorization: Bearer +// 2. X-Api-Key: +// 3. ?api_key= query parameter — only when allowQueryParam is true. // -// If neither location provides a key the middleware responds with 401. -func SeerrAPIKeyMiddleware(allowQueryParam bool, next http.Handler) http.Handler { +// Requests that do not supply a matching credential are rejected with 401. +// The Seerr API key used for outbound Seerr calls is always read from the +// application config (seerr.api_key) and is never sourced from this middleware. +func MCPAuthMiddleware(token string, allowQueryParam bool, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var apiKey string - if v := r.Header.Get("X-Api-Key"); v != "" { - apiKey = v + var provided string + if v := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer "); v != r.Header.Get("Authorization") { + provided = v + } else if v := r.Header.Get("X-Api-Key"); v != "" { + provided = v } else if allowQueryParam { - if v := r.URL.Query().Get("api_key"); v != "" { - apiKey = v - } + provided = r.URL.Query().Get("api_key") } - if apiKey != "" { - ctx := context.WithValue(r.Context(), apiKeyCtxKey, apiKey) - r = r.Clone(ctx) - } else { + if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } - next.ServeHTTP(w, r) }) } diff --git a/tests/mcp_api_key_middleware_test.go b/tests/mcp_api_key_middleware_test.go index d5bf279..8124fb9 100644 --- a/tests/mcp_api_key_middleware_test.go +++ b/tests/mcp_api_key_middleware_test.go @@ -10,100 +10,119 @@ import ( "github.com/stretchr/testify/assert" ) -// captureContextKey is a helper that records the API key injected into the -// request context by the middleware. -func captureContextKey(t *testing.T) (handler http.Handler, captured *string) { - t.Helper() - s := "" - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s, _ = r.Context().Value(cmdmcp.APIKeyContextKey).(string) +const testAuthToken = "super-secret" + +// newAuthTestHandler returns an inner handler that records whether it was called. +func newAuthTestHandler(called *bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + *called = true w.WriteHeader(http.StatusOK) }) - return h, &s } -func TestSeerrAPIKeyMiddleware_headerOnly(t *testing.T) { - inner, captured := captureContextKey(t) - handler := cmdmcp.SeerrAPIKeyMiddleware(false, inner) +func TestMCPAuthMiddleware_bearerToken(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, false, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) - req.Header.Set("X-Api-Key", "header-key") + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Authorization", "Bearer "+testAuthToken) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "header-key", *captured) + assert.True(t, called) } -func TestSeerrAPIKeyMiddleware_queryParamOnly(t *testing.T) { - inner, captured := captureContextKey(t) - handler := cmdmcp.SeerrAPIKeyMiddleware(true, inner) +func TestMCPAuthMiddleware_xApiKeyHeader(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, false, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp?api_key=qparam-key", nil) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("X-Api-Key", testAuthToken) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "qparam-key", *captured) + assert.True(t, called) } -func TestSeerrAPIKeyMiddleware_headerPrecedenceOverQueryParam(t *testing.T) { - inner, captured := captureContextKey(t) - handler := cmdmcp.SeerrAPIKeyMiddleware(true, inner) +func TestMCPAuthMiddleware_queryParam(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, true, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp?api_key=qparam-key", nil) - req.Header.Set("X-Api-Key", "header-key") + req := httptest.NewRequest(http.MethodPost, "/mcp?api_key="+testAuthToken, nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "header-key", *captured, "header must take precedence over query param") + assert.True(t, called) } -func TestSeerrAPIKeyMiddleware_queryParamDisabled_ignoresQueryParam(t *testing.T) { - var innerCalled bool - inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - innerCalled = true - w.WriteHeader(http.StatusOK) - }) - // allowQueryParam=false; query param present but should not satisfy auth. - handler := cmdmcp.SeerrAPIKeyMiddleware(false, inner) +func TestMCPAuthMiddleware_queryParamDisabled_returns401(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, false, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp?api_key=qparam-key", nil) + req := httptest.NewRequest(http.MethodPost, "/mcp?api_key="+testAuthToken, nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) - // No header, query param disabled — must return 401. assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.False(t, innerCalled) + assert.False(t, called) } -func TestSeerrAPIKeyMiddleware_neitherPresent_returns401(t *testing.T) { - var innerCalled bool - inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - innerCalled = true - w.WriteHeader(http.StatusOK) - }) - handler := cmdmcp.SeerrAPIKeyMiddleware(true, inner) +func TestMCPAuthMiddleware_wrongToken_returns401(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, true, newAuthTestHandler(&called)) - req := httptest.NewRequest(http.MethodGet, "/mcp", nil) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("X-Api-Key", "wrong-token") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) assert.Equal(t, http.StatusUnauthorized, rec.Code) - assert.False(t, innerCalled) + assert.False(t, called) } -func TestSeerrAPIKeyMiddleware_queryParam_sensitiveValueNotLogged(t *testing.T) { - // This test ensures SafeLogQuery redacts the api_key value from query strings. +func TestMCPAuthMiddleware_noCredentials_returns401(t *testing.T) { + var called bool + handler := cmdmcp.MCPAuthMiddleware(testAuthToken, true, newAuthTestHandler(&called)) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.False(t, called) +} + +// TestMCPAuthMiddleware_noAuthToken verifies that when no --auth-token is +// configured the middleware is not applied and requests pass through without +// any credential check. This is the no-auth / --no-auth scenario. +func TestMCPAuthMiddleware_noAuthToken_requestPassesThrough(t *testing.T) { + // When authToken is empty the serve command does not wrap the handler with + // MCPAuthMiddleware at all. The Seerr API key is always read from the app + // config (seerr.api_key) and never sourced from the incoming request. + var called bool + inner := newAuthTestHandler(&called) + + // Simulate the no-auth path: handler is NOT wrapped with MCPAuthMiddleware. + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + rec := httptest.NewRecorder() + inner.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.True(t, called) +} + +func TestMCPAuthMiddlewareIsTheOnlyPerRequestAuthMechanism(t *testing.T) { + // Compile-time check that MCPAuthMiddleware exists and has the expected signature. + _ = cmdmcp.MCPAuthMiddleware +} + +func TestMCPLogQueryRedaction(t *testing.T) { + // Ensure SafeLogQuery redacts the api_key value from query strings. redacted := cmdmcp.SafeLogQuery("api_key=secret123&page=1") assert.NotContains(t, redacted, "secret123") assert.Contains(t, redacted, "api_key={redacted}") assert.Contains(t, redacted, "page=1") } - -func TestSeerrAPIKeyMiddlewareIsTheOnlyPerRequestKeyMechanism(t *testing.T) { - // SeerrAPIKeyMiddleware must compile and be the sole per-request API key - // injection mechanism — path-based routing has been removed. - _ = cmdmcp.SeerrAPIKeyMiddleware -} diff --git a/tests/mcp_serve_test.go b/tests/mcp_serve_test.go index 6f1d23a..317cd76 100644 --- a/tests/mcp_serve_test.go +++ b/tests/mcp_serve_test.go @@ -416,8 +416,7 @@ func TestMCPBlocklistListHandler(t *testing.T) { func TestAPIKeyContextPropagation(t *testing.T) { // Verify that an API key injected into the context is forwarded to the - // Seerr API as the X-Api-Key header, matching how SeerrAPIKeyMiddleware - // injects keys for downstream tool handlers. + // Seerr API as the X-Api-Key header by tool handlers. var receivedAPIKey string ts, cleanup := newMCPTestServer(func(w http.ResponseWriter, r *http.Request) {