diff --git a/cmd/altinity-mcp/main.go b/cmd/altinity-mcp/main.go index 083e6d2..1d736da 100644 --- a/cmd/altinity-mcp/main.go +++ b/cmd/altinity-mcp/main.go @@ -244,6 +244,12 @@ func run(args []string) error { Value: false, Sources: cli.EnvVars("OAUTH_CLEAR_CLICKHOUSE_CREDENTIALS"), }, + &cli.StringFlag{ + Name: "forward-http-headers", + Usage: "Comma-separated header name patterns forwarded from incoming requests to ClickHouse (supports * wildcard, e.g. X-*,X-Custom-Header)", + Value: "", + Sources: cli.EnvVars("FORWARD_HTTP_HEADERS"), + }, }, Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) { // Setup logging @@ -525,23 +531,28 @@ func (a *application) startHTTPServer(cfg config.Config, mcpServer *mcp.Server) } // Create a middleware to inject the ClickHouseJWEServer into context + fwdPatterns := cfg.Server.ForwardHTTPHeaders + altinitymcp.WarnOnCatchAllPattern(fwdPatterns) serverInjector := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer) + ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns) next.ServeHTTP(w, r.WithContext(ctx)) }) } serverInjectorOpenAPI := func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer) + ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns) a.mcpServer.OpenAPIHandler(w, r.WithContext(ctx)) } // CORS handler + corsAllowHeaders := altinitymcp.CORSAllowHeaders(fwdPatterns) corsHandler := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", cfg.Server.CORSOrigin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent") + w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders) w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours // Handle preflight requests @@ -617,24 +628,29 @@ func (a *application) startSSEServer(cfg config.Config, mcpServer *mcp.Server) e Msg("Starting MCP server with SSE transport") // Create a middleware to inject the ClickHouseJWEServer into context + fwdPatterns := cfg.Server.ForwardHTTPHeaders + altinitymcp.WarnOnCatchAllPattern(fwdPatterns) serverInjector := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Inject the ClickHouseJWEServer into the context ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer) + ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns) next.ServeHTTP(w, r.WithContext(ctx)) }) } serverInjectorOpenAPI := func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer) + ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns) a.mcpServer.OpenAPIHandler(w, r.WithContext(ctx)) } // CORS handler + corsAllowHeaders := altinitymcp.CORSAllowHeaders(fwdPatterns) corsHandler := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", cfg.Server.CORSOrigin) w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent") + w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders) w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours // Handle preflight requests @@ -985,6 +1001,21 @@ func overrideWithCLIFlags(cfg *config.Config, cmd CommandInterface) { cfg.Server.CORSOrigin = "*" } + // Override forward-http-headers with CLI flags + if cmd.IsSet("forward-http-headers") { + raw := cmd.String("forward-http-headers") + if raw != "" { + patterns := strings.Split(raw, ",") + for i := range patterns { + patterns[i] = strings.TrimSpace(patterns[i]) + } + cfg.Server.ForwardHTTPHeaders = patterns + } else { + cfg.Server.ForwardHTTPHeaders = nil + } + } + + // Override OAuth config with CLI flags if cmd.IsSet("oauth-clear-clickhouse-credentials") { cfg.Server.OAuth.ClearClickHouseCredentials = cmd.Bool("oauth-clear-clickhouse-credentials") diff --git a/pkg/config/config.go b/pkg/config/config.go index f5f6d8e..6d3cdf6 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -127,14 +127,15 @@ type OAuthConfig struct { // ServerConfig defines configuration for the MCP server type ServerConfig struct { - Transport MCPTransport `json:"transport" yaml:"transport" flag:"transport" desc:"MCP transport type (stdio/http/sse)"` - Address string `json:"address" yaml:"address" flag:"address" desc:"Server address for HTTP/SSE transport"` - Port int `json:"port" yaml:"port" flag:"port" desc:"Server port for HTTP/SSE transport"` - TLS ServerTLSConfig `json:"tls" yaml:"tls"` - JWE JWEConfig `json:"jwe" yaml:"jwe"` - OAuth OAuthConfig `json:"oauth" yaml:"oauth"` - OpenAPI OpenAPIConfig `json:"openapi" yaml:"openapi" desc:"OpenAPI endpoints configuration"` - CORSOrigin string `json:"cors_origin" yaml:"cors_origin" flag:"cors-origin" desc:"CORS origin for HTTP/SSE transports (default: *)"` + Transport MCPTransport `json:"transport" yaml:"transport" flag:"transport" desc:"MCP transport type (stdio/http/sse)"` + Address string `json:"address" yaml:"address" flag:"address" desc:"Server address for HTTP/SSE transport"` + Port int `json:"port" yaml:"port" flag:"port" desc:"Server port for HTTP/SSE transport"` + TLS ServerTLSConfig `json:"tls" yaml:"tls"` + JWE JWEConfig `json:"jwe" yaml:"jwe"` + OAuth OAuthConfig `json:"oauth" yaml:"oauth"` + OpenAPI OpenAPIConfig `json:"openapi" yaml:"openapi" desc:"OpenAPI endpoints configuration"` + CORSOrigin string `json:"cors_origin" yaml:"cors_origin" flag:"cors-origin" desc:"CORS origin for HTTP/SSE transports (default: *)"` + ForwardHTTPHeaders []string `json:"forward_http_headers" yaml:"forward_http_headers" desc:"Header name patterns forwarded to ClickHouse (supports * wildcard)"` // DynamicTools defines rules for generating tools from ClickHouse views DynamicTools []DynamicToolRule `json:"dynamic_tools" yaml:"dynamic_tools"` } @@ -147,9 +148,9 @@ type OpenAPIConfig struct { // DynamicToolRule describes a rule to create dynamic tools from views type DynamicToolRule struct { - Name string `json:"name" yaml:"name"` - Regexp string `json:"regexp" yaml:"regexp"` - Prefix string `json:"prefix" yaml:"prefix"` + Name string `json:"name" yaml:"name"` + Regexp string `json:"regexp" yaml:"regexp"` + Prefix string `json:"prefix" yaml:"prefix"` } // LogLevel defines the logging level diff --git a/pkg/server/server.go b/pkg/server/server.go index 1ce3ec1..421b872 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "regexp" + "sort" "strconv" "strings" "sync" @@ -159,39 +160,48 @@ func (s *ClickHouseJWEServer) AddPrompt(prompt *mcp.Prompt, handler PromptHandle }) } -// GetClickHouseClient creates a ClickHouse client from JWE token or falls back to default config +// GetClickHouseClient creates a ClickHouse client from JWE token or falls back to default config. +// Also forwards any HTTP headers stored in context by the middleware. func (s *ClickHouseJWEServer) GetClickHouseClient(ctx context.Context, tokenParam string) (*clickhouse.Client, error) { + return s.GetClickHouseClientWithHeaders(ctx, tokenParam, ForwardedHeadersFromContext(ctx)) +} + +// GetClickHouseClientWithHeaders creates a ClickHouse client, merging optional per-request +// HTTP headers (e.g. X-Tenant-Id) into the config before connecting to ClickHouse. +func (s *ClickHouseJWEServer) GetClickHouseClientWithHeaders(ctx context.Context, tokenParam string, extraHeaders map[string]string) (*clickhouse.Client, error) { + var chConfig config.ClickHouseConfig + if !s.Config.Server.JWE.Enabled { - // If JWE auth is disabled, use the default config - client, err := clickhouse.NewClient(ctx, s.Config.ClickHouse) - if err != nil { - return nil, fmt.Errorf("failed to create ClickHouse client: %w", err) + chConfig = s.Config.ClickHouse + } else { + if tokenParam == "" { + // JWE auth is enabled but no token provided + return nil, jwe_auth.ErrMissingToken } - return client, nil - } - if tokenParam == "" { - // JWE auth is enabled but no token provided - return nil, jwe_auth.ErrMissingToken - } + // Parse and validate JWE token + claims, err := jwe_auth.ParseAndDecryptJWE(tokenParam, []byte(s.Config.Server.JWE.JWESecretKey), []byte(s.Config.Server.JWE.JWTSecretKey)) + if err != nil { + log.Error().Err(err).Msg("failed to parse/decrypt JWE token") + return nil, err + } - // Parse and validate JWE token - claims, err := jwe_auth.ParseAndDecryptJWE(tokenParam, []byte(s.Config.Server.JWE.JWESecretKey), []byte(s.Config.Server.JWE.JWTSecretKey)) - if err != nil { - log.Error().Err(err).Msg("failed to parse/decrypt JWE token") - return nil, err + var buildErr error + // Create ClickHouse config from JWE claims + chConfig, buildErr = s.buildConfigFromClaims(claims) + if buildErr != nil { + return nil, buildErr + } } - // Create ClickHouse config from JWE claims - chConfig, err := s.buildConfigFromClaims(claims) - if err != nil { - return nil, err + if len(extraHeaders) > 0 { + chConfig.HttpHeaders = mergeHTTPHeaders(chConfig.HttpHeaders, extraHeaders) } // Create client with the configured parameters client, err := clickhouse.NewClient(ctx, chConfig) if err != nil { - return nil, fmt.Errorf("failed to create ClickHouse client from JWE: %w", err) + return nil, fmt.Errorf("failed to create ClickHouse client: %w", err) } return client, nil @@ -1643,3 +1653,138 @@ func hasLimitClause(query string) bool { hasLimit, _ := regexp.MatchString(`(?im)limit\s+\d+`, query) return hasLimit } + +// contextKey avoids collisions with other packages using context.WithValue. +type contextKey string + +const forwardedHeadersKey contextKey = "forwarded_http_headers" + +// sensitiveHeaders are excluded from wildcard pattern matching to prevent +// accidental credential leakage. A user can still forward these by naming +// them explicitly (e.g. --forward-http-headers "Authorization"). +var sensitiveHeaders = map[string]bool{ + "Authorization": true, + "Cookie": true, + "Set-Cookie": true, + "Host": true, + "Proxy-Authorization": true, +} + +// WarnOnCatchAllPattern logs a warning if any pattern is a bare "*", +// which would forward all non-sensitive headers to ClickHouse. Call +// once at startup after parsing the config. +func WarnOnCatchAllPattern(patterns []string) { + for _, p := range patterns { + if strings.TrimSpace(p) == "*" { + log.Warn().Msg("forward-http-headers contains \"*\": all headers (except Authorization, Cookie, Host, Set-Cookie, Proxy-Authorization) will be forwarded to ClickHouse; sensitive headers require an explicit pattern") + return + } + } +} + +// ContextWithForwardedHeaders extracts headers matching the given patterns +// from the incoming HTTP request and stores them in context. This makes +// forwarded headers available to every handler path (OpenAPI, MCP JSON-RPC, +// dynamic tools) without coupling to *http.Request. +func ContextWithForwardedHeaders(ctx context.Context, r *http.Request, patterns []string) context.Context { + if headers := extractForwardHeaders(r, patterns); headers != nil { + return context.WithValue(ctx, forwardedHeadersKey, headers) + } + return ctx +} + +// ForwardedHeadersFromContext retrieves forwarded HTTP headers previously +// stored by ContextWithForwardedHeaders. Returns nil when no headers are +// available (e.g. STDIO transport). +func ForwardedHeadersFromContext(ctx context.Context) map[string]string { + if headers, ok := ctx.Value(forwardedHeadersKey).(map[string]string); ok { + return headers + } + return nil +} + +// extractForwardHeaders returns headers matching any of the given patterns. +// Patterns support trailing * wildcard (e.g. "X-*" matches all X-prefixed +// headers) and exact matches (e.g. "X-Tenant-Id"). Matching is +// case-insensitive. Sensitive headers (Authorization, Cookie, …) are +// excluded from wildcard matches but can be forwarded via an explicit +// exact-match pattern. +func extractForwardHeaders(r *http.Request, patterns []string) map[string]string { + if r == nil || len(patterns) == 0 { + return nil + } + headers := make(map[string]string) + for name := range r.Header { + canonical := http.CanonicalHeaderKey(name) + if matchesAnyPattern(canonical, patterns) { + headers[canonical] = r.Header.Get(name) + } + } + if len(headers) == 0 { + return nil + } + names := make([]string, 0, len(headers)) + for k := range headers { + names = append(names, k) + } + sort.Strings(names) + log.Debug().Int("count", len(headers)).Strs("header_names", names).Msg("forwarding HTTP headers to ClickHouse") + return headers +} + +// mergeHTTPHeaders merges extra per-request headers into a base header map, +// returning a new map without mutating either input. +func mergeHTTPHeaders(base, extra map[string]string) map[string]string { + merged := make(map[string]string, len(base)+len(extra)) + for k, v := range base { + merged[k] = v + } + for k, v := range extra { + merged[k] = v + } + return merged +} + +// CORSAllowHeaders builds the Access-Control-Allow-Headers value by combining +// a base set of standard headers with the configured forward patterns. Wildcard +// patterns (e.g. "X-*") are expanded to the CORS spec wildcard "*" since +// browsers don't support prefix wildcards in Access-Control-Allow-Headers. +func CORSAllowHeaders(forwardPatterns []string) string { + base := "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent" + for _, p := range forwardPatterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if strings.HasSuffix(p, "*") { + return base + ", *" + } + base += ", " + p + } + return base +} + +// matchesAnyPattern returns true if header matches at least one pattern. +// Supports trailing * wildcard (e.g. "X-*", "X-Tenant-*") and exact match. +// Comparison is case-insensitive. Wildcard patterns skip sensitive headers; +// only an explicit exact-match pattern can forward them. +func matchesAnyPattern(header string, patterns []string) bool { + lower := strings.ToLower(header) + for _, p := range patterns { + p = strings.ToLower(strings.TrimSpace(p)) + if p == "" { + continue + } + if strings.HasSuffix(p, "*") { + if sensitiveHeaders[http.CanonicalHeaderKey(header)] { + continue + } + if strings.HasPrefix(lower, p[:len(p)-1]) { + return true + } + } else if lower == p { + return true + } + } + return false +} diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index ca8578e..d632c42 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -1983,6 +1983,106 @@ func TestMakeDynamicToolHandler_WithParams(t *testing.T) { require.False(t, result.IsError) } +func TestExtractForwardHeaders(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Custom-Header", "value_a") + req.Header.Set("X-Request-Id", "abc-123") + req.Header.Set("Authorization", "Bearer secret") + req.Header.Set("Cookie", "session=abc") + + t.Run("wildcard pattern forwards matching, excludes non-matching", func(t *testing.T) { + headers := extractForwardHeaders(req, []string{"X-*"}) + require.Len(t, headers, 2) + require.Equal(t, "value_a", headers["X-Custom-Header"]) + require.Equal(t, "abc-123", headers["X-Request-Id"]) + }) + + t.Run("exact pattern restricts to named header only", func(t *testing.T) { + headers := extractForwardHeaders(req, []string{"X-Custom-Header"}) + require.Len(t, headers, 1) + require.Equal(t, "value_a", headers["X-Custom-Header"]) + }) + + t.Run("empty patterns forwards nothing", func(t *testing.T) { + require.Nil(t, extractForwardHeaders(req, nil)) + }) + + t.Run("wildcard excludes sensitive headers", func(t *testing.T) { + headers := extractForwardHeaders(req, []string{"*"}) + require.NotNil(t, headers) + require.Equal(t, "value_a", headers["X-Custom-Header"]) + require.Equal(t, "abc-123", headers["X-Request-Id"]) + require.Empty(t, headers["Authorization"], "Authorization must be blocked by wildcard") + require.Empty(t, headers["Cookie"], "Cookie must be blocked by wildcard") + }) + + t.Run("explicit pattern forwards sensitive header", func(t *testing.T) { + headers := extractForwardHeaders(req, []string{"Authorization"}) + require.Len(t, headers, 1) + require.Equal(t, "Bearer secret", headers["Authorization"]) + }) +} + +func TestContextForwardedHeaders_RoundTrip(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Y-Custom-Header", "value_a") + req.Header.Set("X-Request-Id", "req-42") + req.Header.Set("Authorization", "Bearer secret") + + ctx := ContextWithForwardedHeaders(context.Background(), req, []string{"*"}) + headers := ForwardedHeadersFromContext(ctx) + + require.Equal(t, "value_a", headers["Y-Custom-Header"]) + require.Equal(t, "req-42", headers["X-Request-Id"]) + require.Empty(t, headers["Authorization"], "wildcard must not forward sensitive headers") + require.Nil(t, ForwardedHeadersFromContext(context.Background())) +} + +func TestCORSAllowHeaders(t *testing.T) { + cases := []struct { + name string + input []string + expected string + }{ + {"empty", nil, "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent"}, + {"single", []string{"X-Custom-Header"}, "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent, X-Custom-Header"}, + {"multiple", []string{"X-Custom-Header", "X-Other"}, "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent, X-Custom-Header, X-Other"}, + {"wildcard", []string{"X-*"}, "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent, *"}, + {"mixed", []string{"X-Custom-Header", "X-*"}, "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent, X-Custom-Header, *"}, + {"spaces", []string{" X-Custom-Header "}, "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent, X-Custom-Header"}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + actual := CORSAllowHeaders(c.input) + require.Equal(t, c.expected, actual) + }) + } +} + +// TestMergeHTTPHeaders verifies that mergeHTTPHeaders produces a correct union +// where extra values override base values, and neither input map is mutated. +func TestMergeHTTPHeaders(t *testing.T) { + base := map[string]string{"X-Base": "base", "X-Shared": "from-base"} + extra := map[string]string{"X-Extra": "extra", "X-Shared": "from-extra"} + + merged := mergeHTTPHeaders(base, extra) + + require.Equal(t, "base", merged["X-Base"]) + require.Equal(t, "extra", merged["X-Extra"]) + require.Equal(t, "from-extra", merged["X-Shared"]) + + require.Equal(t, "from-base", base["X-Shared"], "base map must not be mutated") + require.Empty(t, base["X-Extra"], "base map must not be mutated") +} + +// TestMergeHTTPHeaders_NilBase verifies merging into a nil base map works. +func TestMergeHTTPHeaders_NilBase(t *testing.T) { + extra := map[string]string{"X-Extra": "extra"} + merged := mergeHTTPHeaders(nil, extra) + require.Equal(t, "extra", merged["X-Extra"]) + require.Len(t, merged, 1) +} + // Unused import suppressors (remove if unused) var _ = io.EOF var _ = fmt.Sprintf