Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package oauth

import (
"fmt"
"strconv"
"strings"

"github.com/golang-jwt/jwt/v5"
"github.com/tuannvm/oauth-mcp-proxy/provider"
)

Expand All @@ -18,6 +21,7 @@ type Config struct {
Audience string
ClientID string
ClientSecret string
Scopes []string

// Server configuration
ServerURL string // Full URL of the MCP server
Expand All @@ -31,6 +35,12 @@ type Config struct {
// Implement the Logger interface (Debug, Info, Warn, Error methods) to
// integrate with your application's logging system (e.g., zap, logrus).
Logger Logger

// Token validation configuration
SkipIssuerCheck bool
SkipAudienceCheck bool
SkipExpiryCheck bool
TokenValidators []func(claims jwt.MapClaims) error
}

// Validate validates the configuration
Expand Down Expand Up @@ -119,11 +129,15 @@ func SetupOAuth(cfg *Config) (provider.TokenValidator, error) {
func createValidator(cfg *Config, logger Logger) (provider.TokenValidator, error) {
// Convert root Config to provider.Config
providerCfg := &provider.Config{
Provider: cfg.Provider,
Issuer: cfg.Issuer,
Audience: cfg.Audience,
JWTSecret: cfg.JWTSecret,
Logger: logger,
Provider: cfg.Provider,
Issuer: cfg.Issuer,
Audience: cfg.Audience,
JWTSecret: cfg.JWTSecret,
Logger: logger,
SkipIssuerCheck: cfg.SkipIssuerCheck,
SkipAudienceCheck: cfg.SkipAudienceCheck,
SkipExpiryCheck: cfg.SkipExpiryCheck,
TokenValidators: cfg.TokenValidators,
}

var validator provider.TokenValidator
Expand Down Expand Up @@ -217,12 +231,36 @@ func (b *ConfigBuilder) WithJWTSecret(secret []byte) *ConfigBuilder {
return b
}

// WithScopes sets the OIDC scopes
func (b *ConfigBuilder) WithScopes(scopes []string) *ConfigBuilder {
b.config.Scopes = scopes
return b
}

// WithLogger sets the logger
func (b *ConfigBuilder) WithLogger(logger Logger) *ConfigBuilder {
b.config.Logger = logger
return b
}

// WithSkipIssuerCheck sets issuer check toggle
func (b *ConfigBuilder) WithSkipIssuerCheck(skipIssuerCheck bool) *ConfigBuilder {
b.config.SkipIssuerCheck = skipIssuerCheck
return b
}

// WithSkipAudienceCheck sets audience check toggle
func (b *ConfigBuilder) WithSkipAudienceCheck(skipAudienceCheck bool) *ConfigBuilder {
b.config.SkipAudienceCheck = skipAudienceCheck
return b
}

// WithSkipExpiryCheck sets expiry check toggle
func (b *ConfigBuilder) WithSkipExpiryCheck(skipExpiryCheck bool) *ConfigBuilder {
b.config.SkipExpiryCheck = skipExpiryCheck
return b
}

// WithServerURL sets the full server URL directly
func (b *ConfigBuilder) WithServerURL(url string) *ConfigBuilder {
b.config.ServerURL = url
Expand Down Expand Up @@ -281,6 +319,12 @@ func FromEnv() (*Config, error) {

jwtSecret := getEnv("JWT_SECRET", "")

scopes := []string{}
scopesEnv := getEnv("OIDC_SCOPES", "")
if scopesEnv != "" {
scopes = strings.Split(scopesEnv, " ")
}

return NewConfigBuilder().
WithMode(getEnv("OAUTH_MODE", "")).
WithProvider(getEnv("OAUTH_PROVIDER", "")).
Expand All @@ -289,7 +333,24 @@ func FromEnv() (*Config, error) {
WithAudience(getEnv("OIDC_AUDIENCE", "")).
WithClientID(getEnv("OIDC_CLIENT_ID", "")).
WithClientSecret(getEnv("OIDC_CLIENT_SECRET", "")).
WithScopes(scopes).
WithSkipAudienceCheck(parseBoolEnv("OIDC_SKIP_AUDIENCE_CHECK", false)).
WithSkipIssuerCheck(parseBoolEnv("OIDC_SKIP_ISSUER_CHECK", false)).
WithSkipExpiryCheck(parseBoolEnv("OIDC_SKIP_EXPIRY_CHECK", false)).
WithServerURL(serverURL).
WithJWTSecret([]byte(jwtSecret)).
Build()
}

// parseBoolEnv parses a boolean environment variable
func parseBoolEnv(key string, defaultVal bool) bool {
val := getEnv(key, "")
if val == "" {
return defaultVal
}
parsed, err := strconv.ParseBool(val)
if err != nil {
return defaultVal
}
return parsed
}
9 changes: 8 additions & 1 deletion handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type OAuth2Config struct {
Audience string
ClientID string
ClientSecret string
Scopes []string

// Server configuration
MCPHost string
Expand Down Expand Up @@ -96,7 +97,7 @@ func NewOAuth2Handler(cfg *OAuth2Config, logger Logger) *OAuth2Handler {
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: endpoint,
Scopes: []string{"openid", "profile", "email"},
Scopes: cfg.Scopes,
}

// Log client configuration type for debugging
Expand Down Expand Up @@ -177,6 +178,11 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
mcpURL = getEnv("MCP_URL", fmt.Sprintf("%s://%s:%s", scheme, mcpHost, mcpPort))
}

scopes := cfg.Scopes
if len(scopes) == 0 {
scopes = []string{"openid", "profile", "email"}
}

return &OAuth2Config{
Enabled: true,
Mode: cfg.Mode,
Expand All @@ -186,6 +192,7 @@ func NewOAuth2ConfigFromConfig(cfg *Config, version string) *OAuth2Config {
Audience: cfg.Audience,
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Scopes: scopes,
MCPHost: mcpHost,
MCPPort: mcpPort,
MCPURL: mcpURL,
Expand Down
6 changes: 3 additions & 3 deletions metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (h *OAuth2Handler) HandleProtectedResourceMetadata(w http.ResponseWriter, r
"resource_documentation": fmt.Sprintf("%s/docs", h.config.MCPURL),
"resource_policy_uri": fmt.Sprintf("%s/policy", h.config.MCPURL),
"resource_tos_uri": fmt.Sprintf("%s/tos", h.config.MCPURL),
"scopes_supported": []string{"openid", "profile", "email"},
"scopes_supported": h.config.Scopes,
}

// Encode and send response
Expand Down Expand Up @@ -243,7 +243,7 @@ func (h *OAuth2Handler) HandleOIDCDiscovery(w http.ResponseWriter, r *http.Reque
"token_endpoint_auth_methods_supported": []string{"none"},
"code_challenge_methods_supported": []string{"plain", "S256"},
"subject_types_supported": []string{"public"},
"scopes_supported": []string{"openid", "profile", "email"},
"scopes_supported": h.config.Scopes,
}

// Add provider-specific fields
Expand Down Expand Up @@ -283,7 +283,7 @@ func (h *OAuth2Handler) GetAuthorizationServerMetadata() map[string]interface{}
"grant_types_supported": []string{"authorization_code"},
"token_endpoint_auth_methods_supported": []string{"none"},
"code_challenge_methods_supported": []string{"plain", "S256"},
"scopes_supported": []string{"openid", "profile", "email"},
"scopes_supported": h.config.Scopes,
}

// Add provider-specific endpoints
Expand Down
43 changes: 27 additions & 16 deletions provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ type Logger interface {

// Config holds OAuth configuration (subset needed by provider)
type Config struct {
Provider string
Issuer string
Audience string
JWTSecret []byte
Logger Logger
Provider string
Issuer string
Audience string
JWTSecret []byte
Logger Logger
SkipIssuerCheck bool
SkipAudienceCheck bool
SkipExpiryCheck bool
TokenValidators []func(claims jwt.MapClaims) error
}

// TokenValidator interface for OAuth token validation
Expand All @@ -52,10 +56,11 @@ type HMACValidator struct {

// OIDCValidator validates JWT tokens using OIDC/JWKS (Okta, Google, Azure)
type OIDCValidator struct {
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
audience string
logger Logger
verifier *oidc.IDTokenVerifier
provider *oidc.Provider
audience string
TokenValidators []func(claims jwt.MapClaims) error
logger Logger
}

// Initialize sets up the HMAC validator with JWT secret and audience
Expand Down Expand Up @@ -90,7 +95,6 @@ func (v *HMACValidator) ValidateToken(ctx context.Context, tokenString string) (
}
return []byte(v.secret), nil
})

if err != nil {
return nil, fmt.Errorf("failed to parse and validate token: %w", err)
}
Expand Down Expand Up @@ -204,15 +208,19 @@ func (v *OIDCValidator) Initialize(cfg *Config) error {
verifier := provider.Verifier(&oidc.Config{
ClientID: cfg.Audience, // Note: go-oidc uses ClientID field for audience validation - see https://github.com/coreos/go-oidc/blob/v3/oidc/verify.go#L85
SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256},
SkipClientIDCheck: false, // Always validate if ClientID is provided
SkipExpiryCheck: false, // Verify expiration
SkipIssuerCheck: false, // Verify issuer
SkipClientIDCheck: cfg.SkipAudienceCheck,
SkipExpiryCheck: cfg.SkipExpiryCheck,
SkipIssuerCheck: cfg.SkipIssuerCheck,
})

v.logger.Info("OAuth: OIDC validator initialized with audience validation: %s", cfg.Audience)

v.provider = provider
v.verifier = verifier
v.TokenValidators = cfg.TokenValidators
if !cfg.SkipAudienceCheck {
v.TokenValidators = append(v.TokenValidators, v.validateAudience)
}
return nil
}

Expand Down Expand Up @@ -256,9 +264,12 @@ func (v *OIDCValidator) ValidateToken(ctx context.Context, tokenString string) (
return nil, fmt.Errorf("failed to extract raw claims: %w", err)
}

// Validate audience claim for security (explicit check)
if err := v.validateAudience(rawClaims); err != nil {
return nil, fmt.Errorf("audience validation failed: %w", err)
// Run extra validation functions
for i, fn := range v.TokenValidators {
err := fn(rawClaims)
if err != nil {
return nil, fmt.Errorf("validation function %d failed with error: %w", i, err)
}
}

return &User{
Expand Down