Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions cmd/altinity-mcp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ func run(args []string) error {
Value: false,
Sources: cli.EnvVars("OAUTH_CLEAR_CLICKHOUSE_CREDENTIALS"),
},
&cli.StringFlag{
Name: "forward-http-headers",
Usage: "Comma-separated header name patterns forwarded from incoming requests to ClickHouse (supports * wildcard, e.g. X-*,X-Custom-Header)",
Value: "",
Sources: cli.EnvVars("FORWARD_HTTP_HEADERS"),
},
},
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
// Setup logging
Expand Down Expand Up @@ -525,23 +531,28 @@ func (a *application) startHTTPServer(cfg config.Config, mcpServer *mcp.Server)
}

// Create a middleware to inject the ClickHouseJWEServer into context
fwdPatterns := cfg.Server.ForwardHTTPHeaders
altinitymcp.WarnOnCatchAllPattern(fwdPatterns)
serverInjector := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer)
ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
serverInjectorOpenAPI := func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer)
ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns)
a.mcpServer.OpenAPIHandler(w, r.WithContext(ctx))
}

// CORS handler
corsAllowHeaders := altinitymcp.CORSAllowHeaders(fwdPatterns)
corsHandler := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", cfg.Server.CORSOrigin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent")
w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders)
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours

// Handle preflight requests
Expand Down Expand Up @@ -617,24 +628,29 @@ func (a *application) startSSEServer(cfg config.Config, mcpServer *mcp.Server) e
Msg("Starting MCP server with SSE transport")

// Create a middleware to inject the ClickHouseJWEServer into context
fwdPatterns := cfg.Server.ForwardHTTPHeaders
altinitymcp.WarnOnCatchAllPattern(fwdPatterns)
serverInjector := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Inject the ClickHouseJWEServer into the context
ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer)
ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
serverInjectorOpenAPI := func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), "clickhouse_jwe_server", a.mcpServer)
ctx = altinitymcp.ContextWithForwardedHeaders(ctx, r, fwdPatterns)
a.mcpServer.OpenAPIHandler(w, r.WithContext(ctx))
}

// CORS handler
corsAllowHeaders := altinitymcp.CORSAllowHeaders(fwdPatterns)
corsHandler := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", cfg.Server.CORSOrigin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent")
w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders)
w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours

// Handle preflight requests
Expand Down Expand Up @@ -985,6 +1001,21 @@ func overrideWithCLIFlags(cfg *config.Config, cmd CommandInterface) {
cfg.Server.CORSOrigin = "*"
}

// Override forward-http-headers with CLI flags
if cmd.IsSet("forward-http-headers") {
raw := cmd.String("forward-http-headers")
if raw != "" {
patterns := strings.Split(raw, ",")
for i := range patterns {
patterns[i] = strings.TrimSpace(patterns[i])
}
cfg.Server.ForwardHTTPHeaders = patterns
} else {
cfg.Server.ForwardHTTPHeaders = nil
}
}


// Override OAuth config with CLI flags
if cmd.IsSet("oauth-clear-clickhouse-credentials") {
cfg.Server.OAuth.ClearClickHouseCredentials = cmd.Bool("oauth-clear-clickhouse-credentials")
Expand Down
23 changes: 12 additions & 11 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ type OAuthConfig struct {

// ServerConfig defines configuration for the MCP server
type ServerConfig struct {
Transport MCPTransport `json:"transport" yaml:"transport" flag:"transport" desc:"MCP transport type (stdio/http/sse)"`
Address string `json:"address" yaml:"address" flag:"address" desc:"Server address for HTTP/SSE transport"`
Port int `json:"port" yaml:"port" flag:"port" desc:"Server port for HTTP/SSE transport"`
TLS ServerTLSConfig `json:"tls" yaml:"tls"`
JWE JWEConfig `json:"jwe" yaml:"jwe"`
OAuth OAuthConfig `json:"oauth" yaml:"oauth"`
OpenAPI OpenAPIConfig `json:"openapi" yaml:"openapi" desc:"OpenAPI endpoints configuration"`
CORSOrigin string `json:"cors_origin" yaml:"cors_origin" flag:"cors-origin" desc:"CORS origin for HTTP/SSE transports (default: *)"`
Transport MCPTransport `json:"transport" yaml:"transport" flag:"transport" desc:"MCP transport type (stdio/http/sse)"`
Address string `json:"address" yaml:"address" flag:"address" desc:"Server address for HTTP/SSE transport"`
Port int `json:"port" yaml:"port" flag:"port" desc:"Server port for HTTP/SSE transport"`
TLS ServerTLSConfig `json:"tls" yaml:"tls"`
JWE JWEConfig `json:"jwe" yaml:"jwe"`
OAuth OAuthConfig `json:"oauth" yaml:"oauth"`
OpenAPI OpenAPIConfig `json:"openapi" yaml:"openapi" desc:"OpenAPI endpoints configuration"`
CORSOrigin string `json:"cors_origin" yaml:"cors_origin" flag:"cors-origin" desc:"CORS origin for HTTP/SSE transports (default: *)"`
ForwardHTTPHeaders []string `json:"forward_http_headers" yaml:"forward_http_headers" desc:"Header name patterns forwarded to ClickHouse (supports * wildcard)"`
// DynamicTools defines rules for generating tools from ClickHouse views
DynamicTools []DynamicToolRule `json:"dynamic_tools" yaml:"dynamic_tools"`
}
Expand All @@ -147,9 +148,9 @@ type OpenAPIConfig struct {

// DynamicToolRule describes a rule to create dynamic tools from views
type DynamicToolRule struct {
Name string `json:"name" yaml:"name"`
Regexp string `json:"regexp" yaml:"regexp"`
Prefix string `json:"prefix" yaml:"prefix"`
Name string `json:"name" yaml:"name"`
Regexp string `json:"regexp" yaml:"regexp"`
Prefix string `json:"prefix" yaml:"prefix"`
}

// LogLevel defines the logging level
Expand Down
187 changes: 166 additions & 21 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -159,39 +160,48 @@ func (s *ClickHouseJWEServer) AddPrompt(prompt *mcp.Prompt, handler PromptHandle
})
}

// GetClickHouseClient creates a ClickHouse client from JWE token or falls back to default config
// GetClickHouseClient creates a ClickHouse client from JWE token or falls back to default config.
// Also forwards any HTTP headers stored in context by the middleware.
func (s *ClickHouseJWEServer) GetClickHouseClient(ctx context.Context, tokenParam string) (*clickhouse.Client, error) {
return s.GetClickHouseClientWithHeaders(ctx, tokenParam, ForwardedHeadersFromContext(ctx))
}

// GetClickHouseClientWithHeaders creates a ClickHouse client, merging optional per-request
// HTTP headers (e.g. X-Tenant-Id) into the config before connecting to ClickHouse.
func (s *ClickHouseJWEServer) GetClickHouseClientWithHeaders(ctx context.Context, tokenParam string, extraHeaders map[string]string) (*clickhouse.Client, error) {
var chConfig config.ClickHouseConfig

if !s.Config.Server.JWE.Enabled {
// If JWE auth is disabled, use the default config
client, err := clickhouse.NewClient(ctx, s.Config.ClickHouse)
if err != nil {
return nil, fmt.Errorf("failed to create ClickHouse client: %w", err)
chConfig = s.Config.ClickHouse
} else {
if tokenParam == "" {
// JWE auth is enabled but no token provided
return nil, jwe_auth.ErrMissingToken
}
return client, nil
}

if tokenParam == "" {
// JWE auth is enabled but no token provided
return nil, jwe_auth.ErrMissingToken
}
// Parse and validate JWE token
claims, err := jwe_auth.ParseAndDecryptJWE(tokenParam, []byte(s.Config.Server.JWE.JWESecretKey), []byte(s.Config.Server.JWE.JWTSecretKey))
if err != nil {
log.Error().Err(err).Msg("failed to parse/decrypt JWE token")
return nil, err
}

// Parse and validate JWE token
claims, err := jwe_auth.ParseAndDecryptJWE(tokenParam, []byte(s.Config.Server.JWE.JWESecretKey), []byte(s.Config.Server.JWE.JWTSecretKey))
if err != nil {
log.Error().Err(err).Msg("failed to parse/decrypt JWE token")
return nil, err
var buildErr error
// Create ClickHouse config from JWE claims
chConfig, buildErr = s.buildConfigFromClaims(claims)
if buildErr != nil {
return nil, buildErr
}
}

// Create ClickHouse config from JWE claims
chConfig, err := s.buildConfigFromClaims(claims)
if err != nil {
return nil, err
if len(extraHeaders) > 0 {
chConfig.HttpHeaders = mergeHTTPHeaders(chConfig.HttpHeaders, extraHeaders)
}

// Create client with the configured parameters
client, err := clickhouse.NewClient(ctx, chConfig)
if err != nil {
return nil, fmt.Errorf("failed to create ClickHouse client from JWE: %w", err)
return nil, fmt.Errorf("failed to create ClickHouse client: %w", err)
}

return client, nil
Expand Down Expand Up @@ -1643,3 +1653,138 @@ func hasLimitClause(query string) bool {
hasLimit, _ := regexp.MatchString(`(?im)limit\s+\d+`, query)
return hasLimit
}

// contextKey avoids collisions with other packages using context.WithValue.
type contextKey string

const forwardedHeadersKey contextKey = "forwarded_http_headers"

// sensitiveHeaders are excluded from wildcard pattern matching to prevent
// accidental credential leakage. A user can still forward these by naming
// them explicitly (e.g. --forward-http-headers "Authorization").
var sensitiveHeaders = map[string]bool{
"Authorization": true,
"Cookie": true,
"Set-Cookie": true,
"Host": true,
"Proxy-Authorization": true,
}

// WarnOnCatchAllPattern logs a warning if any pattern is a bare "*",
// which would forward all non-sensitive headers to ClickHouse. Call
// once at startup after parsing the config.
func WarnOnCatchAllPattern(patterns []string) {
for _, p := range patterns {
if strings.TrimSpace(p) == "*" {
log.Warn().Msg("forward-http-headers contains \"*\": all headers (except Authorization, Cookie, Host, Set-Cookie, Proxy-Authorization) will be forwarded to ClickHouse; sensitive headers require an explicit pattern")
return
}
}
}

// ContextWithForwardedHeaders extracts headers matching the given patterns
// from the incoming HTTP request and stores them in context. This makes
// forwarded headers available to every handler path (OpenAPI, MCP JSON-RPC,
// dynamic tools) without coupling to *http.Request.
func ContextWithForwardedHeaders(ctx context.Context, r *http.Request, patterns []string) context.Context {
if headers := extractForwardHeaders(r, patterns); headers != nil {
return context.WithValue(ctx, forwardedHeadersKey, headers)
}
return ctx
}

// ForwardedHeadersFromContext retrieves forwarded HTTP headers previously
// stored by ContextWithForwardedHeaders. Returns nil when no headers are
// available (e.g. STDIO transport).
func ForwardedHeadersFromContext(ctx context.Context) map[string]string {
if headers, ok := ctx.Value(forwardedHeadersKey).(map[string]string); ok {
return headers
}
return nil
}

// extractForwardHeaders returns headers matching any of the given patterns.
// Patterns support trailing * wildcard (e.g. "X-*" matches all X-prefixed
// headers) and exact matches (e.g. "X-Tenant-Id"). Matching is
// case-insensitive. Sensitive headers (Authorization, Cookie, …) are
// excluded from wildcard matches but can be forwarded via an explicit
// exact-match pattern.
func extractForwardHeaders(r *http.Request, patterns []string) map[string]string {
if r == nil || len(patterns) == 0 {
return nil
}
headers := make(map[string]string)
for name := range r.Header {
canonical := http.CanonicalHeaderKey(name)
if matchesAnyPattern(canonical, patterns) {
headers[canonical] = r.Header.Get(name)
}
}
if len(headers) == 0 {
return nil
}
names := make([]string, 0, len(headers))
for k := range headers {
names = append(names, k)
}
sort.Strings(names)
log.Debug().Int("count", len(headers)).Strs("header_names", names).Msg("forwarding HTTP headers to ClickHouse")
return headers
}

// mergeHTTPHeaders merges extra per-request headers into a base header map,
// returning a new map without mutating either input.
func mergeHTTPHeaders(base, extra map[string]string) map[string]string {
merged := make(map[string]string, len(base)+len(extra))
for k, v := range base {
merged[k] = v
}
for k, v := range extra {
merged[k] = v
}
return merged
}

// CORSAllowHeaders builds the Access-Control-Allow-Headers value by combining
// a base set of standard headers with the configured forward patterns. Wildcard
// patterns (e.g. "X-*") are expanded to the CORS spec wildcard "*" since
// browsers don't support prefix wildcards in Access-Control-Allow-Headers.
func CORSAllowHeaders(forwardPatterns []string) string {
base := "Content-Type, Authorization, X-Altinity-MCP-Key, Mcp-Protocol-Version, Referer, User-Agent"
for _, p := range forwardPatterns {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if strings.HasSuffix(p, "*") {
return base + ", *"
}
base += ", " + p
}
return base
}

// matchesAnyPattern returns true if header matches at least one pattern.
// Supports trailing * wildcard (e.g. "X-*", "X-Tenant-*") and exact match.
// Comparison is case-insensitive. Wildcard patterns skip sensitive headers;
// only an explicit exact-match pattern can forward them.
func matchesAnyPattern(header string, patterns []string) bool {
lower := strings.ToLower(header)
for _, p := range patterns {
p = strings.ToLower(strings.TrimSpace(p))
if p == "" {
continue
}
if strings.HasSuffix(p, "*") {
if sensitiveHeaders[http.CanonicalHeaderKey(header)] {
continue
}
if strings.HasPrefix(lower, p[:len(p)-1]) {
return true
}
} else if lower == p {
return true
}
}
return false
}
Loading