From 2610744569445d2f2f99cf1ab50b63ea5435ebf2 Mon Sep 17 00:00:00 2001 From: alvarofraguas Date: Thu, 14 May 2026 19:27:33 +0200 Subject: [PATCH] =?UTF-8?q?osctrl-api:=20security=20hardening=20=E2=80=94?= =?UTF-8?q?=20auth=20bedrock,=20env=20secret=20containment,=20shared=20rat?= =?UTF-8?q?e-limit=20+=20audit-log=20infra?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server-side hardening for osctrl-api, plus shared infrastructure (rate-limit package, audit-log helpers, trusted-proxies plumbing) that osctrl-tls also consumes — its consumer-side changes ship in a companion PR so the TLS-facing surface can be tested in isolation. == Auth bedrock == cmd/api: - --auth=jwt is now the default. Refuse to start with --auth=none unless OSCTRL_INSECURE_NO_AUTH=1 is set. When opted in, a 60s warning ticker keeps the deployment from drifting into 'auth-off forever'. - HttpOnly + Secure cookie session for SPA-style clients (osctrl_token). CLI clients with Authorization: Bearer continue to work unchanged. - Double-submit CSRF (osctrl_csrf cookie + X-CSRF-Token header) for mutating cookie-authenticated requests. CLI Bearer flows exempt. - JWT signing-algorithm pin (HMAC only) to defeat alg-confusion attacks (alg:none / RS256-with-HS256-verify). - JWT secret minimum 32 bytes (HS256 needs HMAC key ≥ hash output). Startup fails fast with the openssl one-liner if too short. - Strict 'forwarded headers' trust via --trusted-proxies. Empty default means utils.GetIP ignores X-Forwarded-For / X-Real-IP — an internet attacker can't spoof IPs to defeat rate-limits or poison audit logs. == Env secret containment + cross-env defense == pkg/types: new TLSEnvironmentView — the low-privilege env projection. Omits Secret, EnrollSecretPath, RemoveSecretPath, Certificate, Flags, and every other field that materially contributes to enrolling a node. cmd/api/handlers/environments.go: - EnvironmentHandler now branches on access level: AdminLevel (or super-admin) gets the full storage struct; UserLevel gets the low-priv view. - EnvEnrollHandler / EnvRemoveHandler raised from UserLevel to AdminLevel — both embed the env's enroll/remove secret. - Both handlers log only the target name, not returnData. - EnvActionsHandler 'create' branch validates caller-supplied UUID via EnvUUIDFilter (rejects malformed) and EnvExists (rejects collision). 'delete' branch gets the same validation for symmetry. cmd/api/handlers/queries.go: QueryResultsHandler now precheck-validates the named query belongs to env.ID via h.Queries.Exists(name, env.ID) and returns 404 otherwise. logging.GetQueryResults filtered on 'name' only, so without this gate a user with QueryLevel on env A could pull results from env B by passing B's query name in A's URL. pkg/environments/environments.go: tighten EnvUUIDFilter regex and add axis-pure Exists/UUIDExists helpers so handler checks can match the router's expectations exactly. == Shared rate-limit + audit-log infrastructure == pkg/ratelimit (new): per-key token-bucket rate limiter with idle eviction. Used by osctrl-api for /login here, and by osctrl-tls for /enroll in the companion PR. Tunable burst, window, and key function (KeyByIP today; KeyByIPAndEnv available). pkg/auditlog/audit.go: FailedLogin + FailedEnroll helpers — a clean stream of authn/enrol failures for SoC tooling to alert on brute-force, password-spray, and enroll abuse. pkg/utils/http-utils.go: SetTrustedProxies + an updated GetIP that honors the trusted-proxies set. Empty (default) ignores X-Forwarded-For / X-Real-IP entirely. == SQL hardening + carve path safety == pkg/carves/utils.go: new ValidCarvePath regexp gate. Without this gate a CarveLevel operator could pass \`'; SELECT 1; --\` and pivot 'carve a file' into 'run any SELECT against your fleet' via GenCarveQuery's string concat. cmd/api/handlers/carves.go (CarvesRunHandler): path validated before the SQL splice. Rejected paths return 400. == Authz + audit-log hardening == pkg/users: - bcrypt cost raised from default (10) to 12. CheckLoginCredentials opportunistically re-hashes existing users at next login (no password reset needed). Rehash failure is non-fatal. - New ClearToken empties APIToken AND CSRFToken so any existing JWT + CSRF cookie pair stops validating. Used by future DELETE /api/v1/users/{username}/token in a follow-up PR. cmd/api/handlers/{users,settings,environments}.go: authz tightenings around permission writes, settings PATCH, and env-action service-name validation. pkg/environments/env-cache.go: keep the 2h cleanup interval; introduce an envCacheTTL constant so the value is self-documenting and tunable locally without changing runtime defaults. == Defaults + ops == deploy/config/{api,admin}.yml: flip --audit-log default to true so audit log writes are on by default. Operators can disable with --audit-log=false. Verified: go build ./... clean, go vet ./... clean, go test ./pkg/... ./cmd/api/... ./cmd/tls/... all green. --- cmd/api/auth.go | 122 +++++++++++++++++-- cmd/api/auth_test.go | 88 ++++++++++++++ cmd/api/handlers/carves.go | 9 ++ cmd/api/handlers/environments.go | 168 ++++++++++++++++++++------ cmd/api/handlers/environments_test.go | 91 ++++++++++++++ cmd/api/handlers/login.go | 98 +++++++++++++-- cmd/api/handlers/queries.go | 8 ++ cmd/api/handlers/settings.go | 13 +- cmd/api/handlers/users.go | 41 +++++-- cmd/api/main.go | 57 ++++++++- deploy/config/admin.yml | 2 +- deploy/config/api.yml | 16 ++- go.mod | 1 + go.sum | 2 + pkg/auditlog/audit.go | 45 +++++++ pkg/carves/utils.go | 31 ++++- pkg/carves/utils_test.go | 51 ++++++++ pkg/config/flags.go | 15 ++- pkg/config/types.go | 5 + pkg/environments/env-cache.go | 27 ++++- pkg/environments/environments.go | 24 +++- pkg/ratelimit/ratelimit.go | 144 ++++++++++++++++++++++ pkg/ratelimit/ratelimit_test.go | 108 +++++++++++++++++ pkg/types/types.go | 59 ++++++++- pkg/users/permissions_test.go | 2 +- pkg/users/users.go | 94 ++++++++++++-- pkg/users/users_test.go | 72 ++++++++++- pkg/utils/http-utils.go | 132 ++++++++++++++++++-- pkg/utils/http-utils_test.go | 79 +++++++++++- 29 files changed, 1487 insertions(+), 117 deletions(-) create mode 100644 cmd/api/auth_test.go create mode 100644 cmd/api/handlers/environments_test.go create mode 100644 pkg/carves/utils_test.go create mode 100644 pkg/ratelimit/ratelimit.go create mode 100644 pkg/ratelimit/ratelimit_test.go diff --git a/cmd/api/auth.go b/cmd/api/auth.go index 4bee5551..3c357931 100644 --- a/cmd/api/auth.go +++ b/cmd/api/auth.go @@ -2,11 +2,13 @@ package main import ( "context" + "crypto/subtle" "net/http" "strings" "github.com/jmpsec/osctrl/cmd/api/handlers" "github.com/jmpsec/osctrl/pkg/config" + "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/utils" "github.com/rs/zerolog/log" ) @@ -16,14 +18,79 @@ const ( contextAPI string = "osctrl-api-context" ) -// Helper to extract token from header +// Cookie + header names — kept in sync with cmd/api/handlers/login.go. +const ( + cookieNameToken = "osctrl_token" + cookieNameCSRF = "osctrl_csrf" + headerNameCSRF = "X-CSRF-Token" +) + +// Helper to extract token from the Authorization header first (CLI clients), +// falling back to the SPA's HttpOnly osctrl_token cookie. func extractHeaderToken(r *http.Request) string { - reqToken := r.Header.Get("Authorization") - splitToken := strings.Split(reqToken, "Bearer") - if len(splitToken) != 2 { - return "" + if v := r.Header.Get("Authorization"); v != "" { + splitToken := strings.Split(v, "Bearer") + if len(splitToken) == 2 { + if t := strings.TrimSpace(splitToken[1]); t != "" { + return t + } + } + } + if c, err := r.Cookie(cookieNameToken); err == nil { + return strings.TrimSpace(c.Value) + } + return "" +} + +// mutatingMethods is the set of HTTP verbs that must carry a valid CSRF token. +// GET/HEAD/OPTIONS are read-only and exempt. +var mutatingMethods = map[string]bool{ + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, +} + +// checkCSRF enforces the double-submit CSRF pattern on mutating requests. +// The SPA reads the non-HttpOnly osctrl_csrf cookie and echoes it via the +// X-CSRF-Token header on every mutation; we constant-time-compare: +// 1. header == cookie value (classic double-submit), AND +// 2. cookie value == AdminUser.CSRFToken (defeats a cookie-tossing +// attacker who can set both header and cookie without DB write access). +// +// CLI clients that authenticate purely via Authorization: Bearer (no cookie) +// are exempt — there is no browser to ride a cross-site request from. +// +// Note: AdminUser.CSRFToken rotates on every successful /login (see +// LoginHandler ↦ Users.UpdateMetadata). Concurrent logins of the same user +// race; the loser keeps a cookie that no longer matches the stored value +// and gets 403 on the next mutation. APIToken refresh / clear also clear +// CSRFToken (see pkg/users.UpdateToken / ClearToken) so a stale CSRF +// cookie cannot outlive its session. +func checkCSRF(r *http.Request, username string) bool { + // r.Cookie returns ErrNoCookie only when the cookie name is absent; + // an empty-value cookie returns (cookie, nil). Treating the empty case + // as "Bearer client" would bypass CSRF — instead, the call to + // extractHeaderToken upstream rejects empty-value cookies before we + // reach this function (the trimmed value falls through to "" return). + if _, err := r.Cookie(cookieNameToken); err != nil { + // No session cookie ⇒ Bearer-only client (CLI/CI). Nothing to CSRF. + return true + } + headerToken := strings.TrimSpace(r.Header.Get(headerNameCSRF)) + cookie, err := r.Cookie(cookieNameCSRF) + if err != nil || headerToken == "" { + return false + } + cookieValue := strings.TrimSpace(cookie.Value) + if subtle.ConstantTimeCompare([]byte(headerToken), []byte(cookieValue)) != 1 { + return false + } + user, err := apiUsers.Get(username) + if err != nil || user.CSRFToken == "" { + return false } - return strings.TrimSpace(splitToken[1]) + return subtle.ConstantTimeCompare([]byte(cookieValue), []byte(user.CSRFToken)) == 1 } // Handler to check access to a resource based on the authentication enabled @@ -41,12 +108,51 @@ func handlerAuthCheck(h http.Handler, auth, jwtSecret string) http.Handler { // Set middleware values token := extractHeaderToken(r) if token == "" { - http.Redirect(w, r, forbiddenPath, http.StatusForbidden) + if utils.AcceptsJSON(r) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusUnauthorized, + types.ApiErrorResponse{Error: "unauthorized", Code: "unauthorized"}) + return + } + // 302 is required by http.Redirect; the legacy 403 didn't actually trigger + // a redirect in any browser since http.Redirect demands a 3xx status. + http.Redirect(w, r, forbiddenPath, http.StatusFound) return } claims, valid := apiUsers.CheckToken(jwtSecret, token) if !valid { - http.Redirect(w, r, forbiddenPath, http.StatusForbidden) + if utils.AcceptsJSON(r) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusUnauthorized, + types.ApiErrorResponse{Error: "unauthorized", Code: "unauthorized"}) + return + } + // 302 is required by http.Redirect; the legacy 403 didn't actually trigger + // a redirect in any browser since http.Redirect demands a 3xx status. + http.Redirect(w, r, forbiddenPath, http.StatusFound) + return + } + // Match the presented token against the user's currently-stored APIToken + // so that refresh/delete on /users/{username}/token invalidates old JWTs. + // (CheckToken above only validates the signature.) Service users with no + // stored token are rejected immediately. Constant-time comparison guards + // against timing-side-channel leaks of the stored token. + user, uerr := apiUsers.Get(claims.Username) + tokenMatches := uerr == nil && user.APIToken != "" && + subtle.ConstantTimeCompare([]byte(user.APIToken), []byte(token)) == 1 + if !tokenMatches { + if utils.AcceptsJSON(r) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusUnauthorized, + types.ApiErrorResponse{Error: "unauthorized", Code: "unauthorized"}) + return + } + http.Redirect(w, r, forbiddenPath, http.StatusFound) + return + } + // CSRF guard for cookie-authenticated mutating requests. CLI Bearer + // clients are exempt via the cookieNameToken probe inside checkCSRF. + // + if mutatingMethods[r.Method] && !checkCSRF(r, claims.Username) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusForbidden, + types.ApiErrorResponse{Error: "csrf token missing or invalid", Code: "csrf"}) return } // Update metadata for the user diff --git a/cmd/api/auth_test.go b/cmd/api/auth_test.go new file mode 100644 index 00000000..965d369f --- /dev/null +++ b/cmd/api/auth_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/jmpsec/osctrl/pkg/config" +) + +func TestHandlerAuthCheckJSONvsRedirect(t *testing.T) { + // A no-op inner handler — handlerAuthCheck should never call it when + // there's no valid token. We just need to assert the failure response. + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("inner handler should not be called when auth fails") + }) + + h := handlerAuthCheck(inner, config.AuthJWT, "test-jwt-secret") + + t.Run("Accept application/json returns 401 JSON", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/anything", nil) + req.Header.Set("Accept", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want 401", rr.Code) + } + ct := rr.Header().Get("Content-Type") + if ct == "" || ct[:16] != "application/json" { + t.Fatalf("Content-Type: got %q, want application/json...", ct) + } + }) + + t.Run("default client gets 302 redirect", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/anything", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusFound { + t.Fatalf("status: got %d, want 302", rr.Code) + } + if rr.Header().Get("Location") == "" { + t.Fatal("missing Location header on redirect") + } + }) +} + +func TestExtractHeaderTokenPrefersBearerThenCookie(t *testing.T) { + cases := []struct { + name string + header string + cookie string + want string + }{ + {"bearer header", "Bearer abc.def.ghi", "", "abc.def.ghi"}, + {"cookie fallback", "", "xyz.uvw.123", "xyz.uvw.123"}, + {"bearer wins over cookie", "Bearer header-token", "cookie-token", "header-token"}, + {"no auth at all", "", "", ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + if tc.cookie != "" { + req.AddCookie(&http.Cookie{Name: cookieNameToken, Value: tc.cookie}) + } + got := extractHeaderToken(req) + if got != tc.want { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestMutatingMethodsTable(t *testing.T) { + // Lock the contract that GET/HEAD/OPTIONS bypass CSRF and PUT/PATCH/POST/DELETE require it. + for _, m := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} { + if mutatingMethods[m] { + t.Errorf("read-only method %s should not require CSRF", m) + } + } + for _, m := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} { + if !mutatingMethods[m] { + t.Errorf("mutating method %s must require CSRF", m) + } + } +} diff --git a/cmd/api/handlers/carves.go b/cmd/api/handlers/carves.go index f505d4d2..8d9889e4 100644 --- a/cmd/api/handlers/carves.go +++ b/cmd/api/handlers/carves.go @@ -189,6 +189,15 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "path can not be empty", http.StatusInternalServerError, nil) return } + // Validate the path before it's spliced into the osquery SQL via + // carves.GenCarveQuery. Without this gate a CarveLevel operator + // could inject arbitrary osquery (e.g. `'; SELECT 1; --`) into the + // query that gets distributed to every targeted node — pivoting + // "carve a file" into "run any SELECT". + if !carves.ValidCarvePath(c.Path) { + apiErrorResponse(w, "invalid carve path", http.StatusBadRequest, fmt.Errorf("rejected path %q", c.Path)) + return + } // Make sure the user has permissions to run queries in the environments for _, e := range c.Environments { if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, e) { diff --git a/cmd/api/handlers/environments.go b/cmd/api/handlers/environments.go index 50d84e89..6feb721d 100644 --- a/cmd/api/handlers/environments.go +++ b/cmd/api/handlers/environments.go @@ -25,6 +25,44 @@ var ( } ) +// denyEnv emits a 403 AND an audit-log entry pinned to the env handler's +// resource class. Used by the env-handler family for every deny branch +// so cross-tenant probes leave an SoC-alertable trail. The path comes +// from r.URL.Path; envID is 0 (NoEnvironment) when the deny happened +// before env resolution. +func (h *HandlersApi) denyEnv(w http.ResponseWriter, r *http.Request, ctx ContextValue, envID uint, reason string) { + h.AuditLog.Denied(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], reason, auditlog.LogTypeEnvironment, envID) + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("denied: %s for user %s", reason, ctx[ctxUser])) +} + +// projectEnvironmentView strips the env-secret-bearing fields from +// TLSEnvironment to produce the SPA-canonical low-privilege envelope. +// Callers MUST use this when serving env data to a non-admin (UserLevel / +// QueryLevel / CarveLevel) user. +func projectEnvironmentView(env environments.TLSEnvironment) types.TLSEnvironmentView { + return types.TLSEnvironmentView{ + ID: env.ID, + CreatedAt: env.CreatedAt, + UpdatedAt: env.UpdatedAt, + UUID: env.UUID, + Name: env.Name, + Hostname: env.Hostname, + Type: env.Type, + Icon: env.Icon, + DebugHTTP: env.DebugHTTP, + ConfigTLS: env.ConfigTLS, + ConfigInterval: env.ConfigInterval, + LoggingTLS: env.LoggingTLS, + LogInterval: env.LogInterval, + QueryTLS: env.QueryTLS, + QueryInterval: env.QueryInterval, + CarvesTLS: env.CarvesTLS, + AcceptEnrolls: env.AcceptEnrolls, + EnrollExpire: env.EnrollExpire, + RemoveExpire: env.RemoveExpire, + } +} + // EnvironmentHandler - GET Handler to return one environment by UUID as JSON func (h *HandlersApi) EnvironmentHandler(w http.ResponseWriter, r *http.Request) { // Debug HTTP if enabled @@ -50,13 +88,21 @@ func (h *HandlersApi) EnvironmentHandler(w http.ResponseWriter, r *http.Request) // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } - // Serialize and serve JSON - log.Debug().Msgf("Returned environment %s", env.Name) + // Decide projection by privilege level: admins on this env (or + // super-admins) receive the full storage struct including secret / + // certificate / flags. UserLevel operators receive the low-privilege + // view that omits enroll credentials. h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, env) + if h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + log.Debug().Msgf("Returned environment %s (admin view)", env.Name) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, env) + return + } + log.Debug().Msgf("Returned environment %s (low-priv view)", env.Name) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, projectEnvironmentView(env)) } // EnvironmentMapHandler - GET Handler to return one environment as JSON @@ -79,7 +125,7 @@ func (h *HandlersApi) EnvironmentMapHandler(w http.ResponseWriter, r *http.Reque // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, auditlog.NoEnvironment, "permission check failed") return } // Prepare map by target @@ -112,7 +158,7 @@ func (h *HandlersApi) EnvironmentsHandler(w http.ResponseWriter, r *http.Request // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, auditlog.NoEnvironment, "permission check failed") return } // Get platforms @@ -149,10 +195,15 @@ func (h *HandlersApi) EnvEnrollHandler(w http.ResponseWriter, r *http.Request) { } return } - // Get context data and check access + // Get context data and check access. The enroll endpoint exposes the + // env's enroll secret (directly via target=secret, indirectly via the + // one-liners that embed it in the URL, and via target=flags). That + // secret is the only credential needed to enroll nodes via osctrl-tls, + // so it must be gated to AdminLevel on the env, not UserLevel. + // ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract target @@ -185,8 +236,9 @@ func (h *HandlersApi) EnvEnrollHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "invalid target", http.StatusBadRequest, fmt.Errorf("invalid target %s", targetVar)) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned data for environment%s : %s", env.Name, returnData) + // Serialize and serve JSON. Don't log the payload — it contains the + // enroll secret. + log.Debug().Msgf("Returned enroll data for environment %s target=%s", env.Name, targetVar) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiDataResponse{Data: returnData}) } @@ -213,10 +265,12 @@ func (h *HandlersApi) EnvRemoveHandler(w http.ResponseWriter, r *http.Request) { } return } - // Get context data and check access + // Get context data and check access. The remove one-liners embed the + // remove-secret in the URL, so the endpoint must be AdminLevel-gated + // just like the enroll variant. ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract target @@ -243,8 +297,9 @@ func (h *HandlersApi) EnvRemoveHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "invalid target", http.StatusBadRequest, fmt.Errorf("invalid target %s", targetVar)) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned data for environment %s : %s", env.Name, returnData) + // Serialize and serve JSON. Don't log the payload — it embeds the + // remove secret. + log.Debug().Msgf("Returned remove data for environment %s target=%s", env.Name, targetVar) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiDataResponse{Data: returnData}) } @@ -274,7 +329,7 @@ func (h *HandlersApi) EnvEnrollActionsHandler(w http.ResponseWriter, r *http.Req // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract action @@ -374,7 +429,7 @@ func (h *HandlersApi) EnvRemoveActionsHandler(w http.ResponseWriter, r *http.Req // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract action @@ -433,7 +488,7 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, auditlog.NoEnvironment, "permission check failed") return } var e types.ApiEnvRequest @@ -450,6 +505,23 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) apiErrorResponse(w, "invalid data", http.StatusBadRequest, nil) return } + // Validate the optional client-supplied UUID strictly. + // - utils.CheckUUID delegates to google/uuid Parse, accepting only + // canonical UUIDs. EnvUUIDFilter alone is `^[a-z0-9-]+$`, which + // would have happily accepted "-", "a", "deadbeef", etc. + // - ExistsByUUID (vs the polymorphic Exists) ensures a UUID-collision + // check cannot match against an existing env's NAME. The old + // Exists(e.UUID) leaked information across axes. + if e.UUID != "" { + if !utils.CheckUUID(e.UUID) { + apiErrorResponse(w, "invalid uuid", http.StatusBadRequest, fmt.Errorf("rejected uuid %q", e.UUID)) + return + } + if h.Envs.ExistsByUUID(e.UUID) { + apiErrorResponse(w, "uuid already in use", http.StatusConflict, fmt.Errorf("uuid %q collides", e.UUID)) + return + } + } // Check if environment already exists if !h.Envs.Exists(e.Name) { env := h.Envs.Empty(e.Name, e.Hostname) @@ -481,18 +553,18 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) } // Create a tag for this new environment if !h.Tags.Exists(env.Name) { - if err := h.Tags.NewTag( - env.Name, - "Tag for environment "+env.Name, - "", - env.Icon, - ctx[ctxUser], - env.ID, - false, - tags.TagTypeEnv, - ""); err != nil { - msgReturn = fmt.Sprintf("error generating tag %s ", err.Error()) - return + if err := h.Tags.NewTag( + env.Name, + "Tag for environment "+env.Name, + "", + env.Icon, + ctx[ctxUser], + env.ID, + false, + tags.TagTypeEnv, + ""); err != nil { + msgReturn = fmt.Sprintf("error generating tag %s ", err.Error()) + return } } msgReturn = "environment created successfully" @@ -501,21 +573,37 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) return } case "delete": - // Verify request fields + // Validate both name and UUID strictly, then verify they refer to + // the SAME environment so the request can't authorise via one + // env's UUID while targeting another env by name. The previous + // shape (polymorphic Exists(e.UUID) → Delete(e.Name)) allowed + // that authorisation/target split. if !environments.EnvNameFilter(e.Name) { apiErrorResponse(w, "invalid environment name", http.StatusBadRequest, nil) return } - if h.Envs.Exists(e.UUID) { - if err := h.Envs.Delete(e.Name); err != nil { - apiErrorResponse(w, "error deleting environment", http.StatusInternalServerError, err) - return - } - msgReturn = "environment deleted successfully" - } else { - apiErrorResponse(w, "environment not found", http.StatusNotFound, fmt.Errorf("environment %s not found", e.Name)) + if e.UUID == "" { + apiErrorResponse(w, "missing environment UUID", http.StatusBadRequest, nil) + return + } + if !utils.CheckUUID(e.UUID) { + apiErrorResponse(w, "invalid environment UUID", http.StatusBadRequest, nil) + return + } + targetEnv, getErr := h.Envs.GetByUUID(e.UUID) + if getErr != nil { + apiErrorResponse(w, "environment not found", http.StatusNotFound, fmt.Errorf("environment %s not found", e.UUID)) + return + } + if targetEnv.Name != e.Name { + apiErrorResponse(w, "name does not match the environment with that UUID", http.StatusBadRequest, fmt.Errorf("uuid %s maps to name %q, body claims %q", e.UUID, targetEnv.Name, e.Name)) + return + } + if err := h.Envs.Delete(targetEnv.Name); err != nil { + apiErrorResponse(w, "error deleting environment", http.StatusInternalServerError, err) return } + msgReturn = "environment deleted successfully" case "edit": // Verify request fields if !environments.EnvUUIDFilter(e.UUID) { diff --git a/cmd/api/handlers/environments_test.go b/cmd/api/handlers/environments_test.go new file mode 100644 index 00000000..bbe332cf --- /dev/null +++ b/cmd/api/handlers/environments_test.go @@ -0,0 +1,91 @@ +package handlers + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/jmpsec/osctrl/pkg/environments" + "gorm.io/gorm" +) + +// TestProjectEnvironmentViewStripsSecrets is the load-bearing regression test +// for the env-secret-containment fix. projectEnvironmentView returns the SPA +// envelope served to UserLevel operators; if a future contributor adds a new +// secret-bearing field to TLSEnvironment without extending the projection, +// the field will leak into the low-priv response. This test marshals the +// projection from a fully-populated source struct and asserts every +// known-sensitive substring is absent from the serialized JSON. +func TestProjectEnvironmentViewStripsSecrets(t *testing.T) { + src := environments.TLSEnvironment{ + Model: gorm.Model{ + ID: 1, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + UUID: "11111111-2222-3333-4444-555555555555", + Name: "prod", + Hostname: "osctrl.example.com", + Type: "dev", + Icon: "rocket", + // The fields below must NOT appear in the projection. + Secret: "SECRET-MARKER-enroll", + EnrollSecretPath: "SECRET-MARKER-enroll-path", + RemoveSecretPath: "SECRET-MARKER-remove-path", + Certificate: "SECRET-MARKER-cert", + Flags: "SECRET-MARKER-flags", + Options: "SECRET-MARKER-options", + Schedule: "SECRET-MARKER-schedule", + Packs: "SECRET-MARKER-packs", + Decorators: "SECRET-MARKER-decorators", + ATC: "SECRET-MARKER-atc", + Configuration: "SECRET-MARKER-configuration", + DebPackage: "SECRET-MARKER-deb", + RpmPackage: "SECRET-MARKER-rpm", + MsiPackage: "SECRET-MARKER-msi", + PkgPackage: "SECRET-MARKER-pkg", + EnrollPath: "SECRET-MARKER-enroll-route", + LogPath: "SECRET-MARKER-log-route", + ConfigPath: "SECRET-MARKER-config-route", + QueryReadPath: "SECRET-MARKER-qread-route", + QueryWritePath: "SECRET-MARKER-qwrite-route", + CarverInitPath: "SECRET-MARKER-carver-init", + CarverBlockPath: "SECRET-MARKER-carver-block", + UserID: 42, + // Operational fields that ARE expected in the view: + ConfigInterval: 60, + LogInterval: 30, + QueryInterval: 10, + AcceptEnrolls: true, + } + + view := projectEnvironmentView(src) + out, err := json.Marshal(view) + if err != nil { + t.Fatalf("marshal: %v", err) + } + body := string(out) + + // Field set + tag names assertions. + wantFields := []string{ + `"uuid":"11111111-2222-3333-4444-555555555555"`, + `"name":"prod"`, + `"hostname":"osctrl.example.com"`, + `"icon":"rocket"`, + `"config_interval":60`, + `"log_interval":30`, + `"query_interval":10`, + `"accept_enrolls":true`, + } + for _, w := range wantFields { + if !strings.Contains(body, w) { + t.Errorf("expected %q in view JSON, got: %s", w, body) + } + } + + // Every SECRET-MARKER must be absent. + if strings.Contains(body, "SECRET-MARKER") { + t.Fatalf("view leaked at least one secret-bearing field: %s", body) + } +} diff --git a/cmd/api/handlers/login.go b/cmd/api/handlers/login.go index 4926890a..8eb75752 100644 --- a/cmd/api/handlers/login.go +++ b/cmd/api/handlers/login.go @@ -1,9 +1,12 @@ package handlers import ( + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "net/http" + "time" "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/users" @@ -22,10 +25,13 @@ func (h *HandlersApi) LoginHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment by UUID - env, err := h.Envs.GetByUUID(envVar) + // Resolve environment by name OR UUID. The SPA login form lets users type + // the env name ("dev", "prod") because UUIDs are not memorable; the API + // must accept either. Get() uses `name = ? OR uuid = ?` so both shapes + // resolve to the same row. A miss returns 404, not 500. + env, err := h.Envs.Get(envVar) if err != nil { - apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) + apiErrorResponse(w, "environment not found", http.StatusNotFound, nil) return } var l types.ApiLoginRequest @@ -34,31 +40,101 @@ func (h *HandlersApi) LoginHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) return } - // Check credentials + // Check credentials. Audit-log every credential failure so SoC tooling + // has a stream to alert on (brute-force, password spray). The IP comes + // from utils.GetIP so X-Real-IP / X-Forwarded-For behind a reverse + // proxy is honored. access, user := h.Users.CheckLoginCredentials(l.Username, l.Password) if !access { + h.AuditLog.FailedLogin(l.Username, utils.GetIP(r), "invalid credentials") apiErrorResponse(w, "invalid credentials", http.StatusForbidden, err) return } // Check if user has access to this environment if !h.Users.CheckPermissions(l.Username, users.AdminLevel, env.UUID) { + h.AuditLog.FailedLogin(l.Username, utils.GetIP(r), fmt.Sprintf("no admin access to env %s", env.UUID)) apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use %s by user %s", h.ServiceName, l.Username)) return } - // Do we have a token already? - if user.APIToken == "" { - token, exp, err := h.Users.CreateToken(l.Username, h.ServiceName, l.ExpHours) + // Decide whether to reuse the stored token or mint a fresh one. Re-issue + // when there's no token, when the stored token has already expired (the + // reuse path used to return 500 "token already expired" — a regression + // that locked users out after their first session expired), or when the + // stored token is within 60s of expiring so we don't hand out something + // that will fail mid-request. + var tokenExp time.Time + now := time.Now() + const freshnessWindow = 60 * time.Second + needsRefresh := user.APIToken == "" || user.TokenExpire.Before(now.Add(freshnessWindow)) + if needsRefresh { + var token string + token, tokenExp, err = h.Users.CreateToken(l.Username, h.ServiceName, l.ExpHours) if err != nil { apiErrorResponse(w, "error creating token", http.StatusInternalServerError, err) return } - if err = h.Users.UpdateToken(l.Username, token, exp); err != nil { + if err = h.Users.UpdateToken(l.Username, token, tokenExp); err != nil { apiErrorResponse(w, "error updating token", http.StatusInternalServerError, err) return } user.APIToken = token + } else { + tokenExp = user.TokenExpire } - h.AuditLog.NewLogin(l.Username, r.RemoteAddr) - // Serialize and serve JSON - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiLoginResponse{Token: user.APIToken}) + // Generate a CSRF token: 16 random bytes encoded as 32 hex chars. + // This cookie is NOT HttpOnly so the SPA can read it and echo it back + // via the X-CSRF-Token header on mutating requests. + csrfBytes := make([]byte, 16) + if _, err = rand.Read(csrfBytes); err != nil { + apiErrorResponse(w, "error generating csrf token", http.StatusInternalServerError, err) + return + } + csrfToken := hex.EncodeToString(csrfBytes) + // Persist the CSRF token alongside the user so the auth middleware can + // verify subsequent X-CSRF-Token headers. Without this write the SPA's + // double-submit pattern is purely cosmetic. + // IP comes from utils.GetIP so it matches the format every other site + // writes to last_ip_address (clean IP, X-Real-IP / X-Forwarded-For aware). + clientIP := utils.GetIP(r) + if err := h.Users.UpdateMetadata(clientIP, r.UserAgent(), l.Username, csrfToken); err != nil { + apiErrorResponse(w, "error persisting csrf token", http.StatusInternalServerError, err) + return + } + // Compute cookie Max-Age from token expiry. + maxAge := int(time.Until(tokenExp).Seconds()) + if maxAge <= 0 { + apiErrorResponse(w, "token already expired", http.StatusInternalServerError, fmt.Errorf("token expiry in past or zero: %v", tokenExp)) + return + } + // Set the httpOnly session cookie. The SPA reads the JWT via the cookie; + // it never needs to access this cookie from JS. + // Secure: true requires HTTPS. If TLS is terminated at a proxy that speaks + // plain HTTP to this service, set Secure:false in the proxy's cookie rewrite + // rule — do not add an --insecure-cookies flag to keep the surface small. + http.SetCookie(w, &http.Cookie{ + Name: "osctrl_token", + Value: user.APIToken, + Path: "/", + MaxAge: maxAge, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + // Set the CSRF cookie (not HttpOnly — SPA must read it). + http.SetCookie(w, &http.Cookie{ + Name: "osctrl_csrf", + Value: csrfToken, + Path: "/", + MaxAge: maxAge, + HttpOnly: false, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + h.AuditLog.NewLogin(l.Username, clientIP) + // Serialize and serve JSON. Token stays in the body for backward compat + // with CLI consumers that do not use cookies. + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiLoginResponse{ + Token: user.APIToken, + CSRFToken: csrfToken, + }) } diff --git a/cmd/api/handlers/queries.go b/cmd/api/handlers/queries.go index 93afa0b6..36f341a5 100644 --- a/cmd/api/handlers/queries.go +++ b/cmd/api/handlers/queries.go @@ -372,6 +372,14 @@ func (h *HandlersApi) QueryResultsHandler(w http.ResponseWriter, r *http.Request apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } + // Verify the named query belongs to THIS env. logging.GetQueryResults + // filters on `name` only — without this gate a user with QueryLevel on + // env A could pull results from env B by passing B's query name in + // A's URL. + if !h.Queries.Exists(name, env.ID) { + apiErrorResponse(w, "query not found", http.StatusNotFound, nil) + return + } // Get query by name // TODO this is a temporary solution, we need to refactor this and take into consideration the // logger for TLS and whether if the results are stored in the DB or a different DB diff --git a/cmd/api/handlers/settings.go b/cmd/api/handlers/settings.go index 5c9569f1..985fbabd 100644 --- a/cmd/api/handlers/settings.go +++ b/cmd/api/handlers/settings.go @@ -110,8 +110,11 @@ func (h *HandlersApi) SettingsServiceEnvHandler(w http.ResponseWriter, r *http.R apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get settings - serviceSettings, err := h.Settings.RetrieveValues(service, false, settings.NoEnvironmentID) + // Get settings scoped to THIS env. Previously this passed + // NoEnvironmentID and silently returned global settings, which let an + // env-X admin read another env's values as a side-channel via the + // env-scoped route. + serviceSettings, err := h.Settings.RetrieveValues(service, false, env.ID) if err != nil { apiErrorResponse(w, "error getting settings", http.StatusInternalServerError, err) return @@ -196,8 +199,10 @@ func (h *HandlersApi) SettingsServiceEnvJSONHandler(w http.ResponseWriter, r *ht apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get settings - serviceSettings, err := h.Settings.RetrieveValues(service, true, settings.NoEnvironmentID) + // Get settings scoped to THIS env. Same defense as + // SettingsServiceEnvHandler above; was silently returning global + // settings via NoEnvironmentID. + serviceSettings, err := h.Settings.RetrieveValues(service, true, env.ID) if err != nil { apiErrorResponse(w, "error getting settings", http.StatusInternalServerError, err) return diff --git a/cmd/api/handlers/users.go b/cmd/api/handlers/users.go index 759f80c4..7823defb 100644 --- a/cmd/api/handlers/users.go +++ b/cmd/api/handlers/users.go @@ -13,6 +13,26 @@ import ( "github.com/rs/zerolog/log" ) +// projectAdminUserView strips network-and-timing metadata +// (LastIPAddress / LastUserAgent / LastAccess / LastTokenUse) from an +// AdminUser before serialization to a cross-user reader. Operators +// querying their own row use /api/v1/users/me's full UserMeResponse. +func projectAdminUserView(u users.AdminUser) types.AdminUserView { + return types.AdminUserView{ + ID: u.ID, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + Username: u.Username, + Email: u.Email, + Fullname: u.Fullname, + Admin: u.Admin, + Service: u.Service, + UUID: u.UUID, + TokenExpire: u.TokenExpire, + EnvironmentID: u.EnvironmentID, + } +} + // UserHandler - GET Handler for environment users func (h *HandlersApi) UserHandler(w http.ResponseWriter, r *http.Request) { // Debug HTTP if enabled @@ -37,10 +57,12 @@ func (h *HandlersApi) UserHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error getting user", http.StatusInternalServerError, nil) return } - // Serialize and serve JSON + // Serialize and serve the PII-minimized view; the full user record + // is only available to the user themselves via /api/v1/users/me. + // log.Debug().Msgf("Returned user %s", usernameVar) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], auditlog.NoEnvironment) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, user) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, projectAdminUserView(user)) } // UsersHandler - GET Handler for multiple JSON nodes @@ -56,19 +78,24 @@ func (h *HandlersApi) UsersHandler(w http.ResponseWriter, r *http.Request) { return } // Get users - users, err := h.Users.All() + all, err := h.Users.All() if err != nil { apiErrorResponse(w, "error getting users", http.StatusInternalServerError, err) return } - if len(users) == 0 { + if len(all) == 0 { apiErrorResponse(w, "no users", http.StatusNotFound, nil) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned %d users", len(users)) + // PII-minimized view for the cross-user list — see projectAdminUserView. + // + views := make([]types.AdminUserView, 0, len(all)) + for _, u := range all { + views = append(views, projectAdminUserView(u)) + } + log.Debug().Msgf("Returned %d users", len(views)) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], auditlog.NoEnvironment) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, users) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, views) } // UserActionHandler - POST Handler to take actions on a user by username and environment diff --git a/cmd/api/main.go b/cmd/api/main.go index dc569d02..231f7e3e 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -21,9 +21,11 @@ import ( "github.com/jmpsec/osctrl/pkg/logging" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/ratelimit" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/tags" "github.com/jmpsec/osctrl/pkg/users" + "github.com/jmpsec/osctrl/pkg/utils" "github.com/jmpsec/osctrl/pkg/version" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -185,8 +187,49 @@ func checkLatestRelease() { } } +// guardAuthMode refuses to start the API with --auth=none unless the operator +// explicitly opts in via OSCTRL_INSECURE_NO_AUTH=1. When the opt-in is set, +// every 60s a loud warning is logged so the deployment cannot drift into +// "auth-off forever" without anyone noticing. +// +// The warning goroutine watches the supplied context so a future graceful +// shutdown path can cancel it cleanly. Today the API has no shutdown signal +// handling so the context never fires — that's acceptable; we get the +// no-leak property for free when shutdown is added. +func guardAuthMode(ctx context.Context, auth string) { + if auth != config.AuthNone { + return + } + if os.Getenv("OSCTRL_INSECURE_NO_AUTH") != "1" { + log.Fatal().Msg("auth=none is disabled by default. Set OSCTRL_INSECURE_NO_AUTH=1 to opt in for local development only — every request will be served as super-admin") + } + go func() { + log.Warn().Msg("INSECURE: osctrl-api running with auth=none — every request is served as super-admin. DO NOT use in production") + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + log.Warn().Msg("INSECURE: osctrl-api running with auth=none — every request is served as super-admin. DO NOT use in production") + } + } + }() +} + // Go go! func osctrlAPIService() { + // Refuse to run unauthenticated unless the operator explicitly opts in. + guardAuthMode(context.Background(), flagParams.Service.Auth) + // Configure forwarding-header trust. Empty (default) means utils.GetIP + // ignores X-Forwarded-For / X-Real-IP and always uses RemoteAddr, so + // an internet attacker can't spoof IPs to defeat rate-limits or + // poison the audit log. + if tp := strings.TrimSpace(flagParams.Service.TrustedProxies); tp != "" { + utils.SetTrustedProxies(strings.Split(tp, ",")) + log.Info().Msgf("Trusting forwarding headers from: %s", tp) + } // ////////////////////////////// Backend log.Info().Msg("Initializing backend...") for { @@ -265,7 +308,6 @@ func osctrlAPIService() { handlers.WithAuditLog(auditLog), handlers.WithDebugHTTP(flagParams.Debug), handlers.WithOsqueryValues(*flagParams.Osquery), - ) // ///////////////////////// API @@ -284,7 +326,16 @@ func osctrlAPIService() { muxAPI.HandleFunc("GET "+_apiPath(checksNoAuthPath), handlersApi.CheckHandlerNoAuth) // ///////////////////////// UNAUTHENTICATED - muxAPI.HandleFunc("POST "+_apiPath(apiLoginPath)+"/{env}", handlersApi.LoginHandler) + // Login is the only password-acceptance surface on the API. Cap to + // 10 attempts per IP per minute (token-bucket; bursts of 10, refill + // at 1/6s) and 429 the rest. Rejections are audit-logged inside the + // LoginHandler / RateLimit middleware so SoC tooling sees the spray. + // + loginLimiter := ratelimit.New(10, time.Minute, 10*time.Minute) + loginRateLimit := loginLimiter.HTTPMiddleware(ratelimit.KeyByIP, func(r *http.Request, key string) { + handlersApi.AuditLog.FailedLogin("", utils.GetIP(r), "rate limit exceeded") + }) + muxAPI.Handle("POST "+_apiPath(apiLoginPath)+"/{env}", loginRateLimit(http.HandlerFunc(handlersApi.LoginHandler))) // ///////////////////////// AUTHENTICATED // API: check auth muxAPI.Handle( @@ -392,7 +443,7 @@ func osctrlAPIService() { handlerAuthCheck(http.HandlerFunc(handlersApi.EnvEnrollActionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "GET "+_apiPath(apiEnvironmentsPath)+"/{env}/remove/{target}", - handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvRemoveHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "POST "+_apiPath(apiEnvironmentsPath)+"/{env}/remove/{action}", handlerAuthCheck(http.HandlerFunc(handlersApi.EnvRemoveActionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) diff --git a/deploy/config/admin.yml b/deploy/config/admin.yml index 1099e916..a5b3d155 100644 --- a/deploy/config/admin.yml +++ b/deploy/config/admin.yml @@ -10,7 +10,7 @@ service: host: osctrl.net # Valid values: "none", "json", "db", "saml", "oidc", "oauth" auth: none - auditLog: false + auditLog: true # Database configuration db: diff --git a/deploy/config/api.yml b/deploy/config/api.yml index e77c8c1f..9e90fba7 100644 --- a/deploy/config/api.yml +++ b/deploy/config/api.yml @@ -8,9 +8,19 @@ service: # Valid values: "json", "console" logFormat: json host: osctrl.net - # Valid values: "none", "json", "db", "saml", "oidc", "oauth" - auth: none - auditLog: false + # Valid values: "jwt", "none". `none` requires OSCTRL_INSECURE_NO_AUTH=1 + # in the environment and is intended for local-dev only — it impersonates + # super-admin on every request. Production deployments MUST use `jwt`. + auth: jwt + auditLog: true + # Comma-separated CIDR list whose X-Real-IP / X-Forwarded-For headers + # utils.GetIP will trust. Leave empty (default) when osctrl-api is + # directly internet-facing — forwarding headers are then ignored and + # RemoteAddr is used verbatim, preventing header-spoofed rate-limit + # bypass and audit-log poisoning. Set to your edge proxy's CIDR(s) + # when osctrl-api sits behind a trusted reverse proxy (e.g. + # `10.0.0.0/8` or `192.0.2.1/32,2001:db8::/64`). + trustedProxies: "" # Database configuration db: diff --git a/go.mod b/go.mod index ac619ec2..fb4fcb49 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( golang.org/x/oauth2 v0.36.0 golang.org/x/term v0.42.0 golang.org/x/text v0.36.0 + golang.org/x/time v0.15.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.6.0 gorm.io/driver/postgres v1.6.0 diff --git a/go.sum b/go.sum index 888ff6fc..b15d7339 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/auditlog/audit.go b/pkg/auditlog/audit.go index 29dc5487..bd1d3246 100644 --- a/pkg/auditlog/audit.go +++ b/pkg/auditlog/audit.go @@ -114,6 +114,35 @@ func (m *AuditLogManager) NewLogin(username, ip string) { } } +// FailedLogin records a failed login attempt — invalid credentials, missing +// permission, or any other reason the login flow refused to mint a token. +// `reason` is a short free-text string suitable for SoC alerting and MUST +// NOT contain the offered password. Severity warning so it sticks out next +// to the successful-login firehose. +func (m *AuditLogManager) FailedLogin(username, ip, reason string) { + if !m.Enabled { + return + } + line := fmt.Sprintf("failed login for user %s: %s", username, reason) + if err := m.CreateNew(username, line, ip, LogTypeLogin, SeverityWarning, NoEnvironment); err != nil { + log.Err(err).Msg("error creating failed-login audit log") + } +} + +// FailedEnroll records a failed osquery-node enrollment attempt — invalid +// env secret, denied env, malformed payload. Severity warning, scoped to +// the env in the path (envID == 0 when the env itself was the failure +// reason). +func (m *AuditLogManager) FailedEnroll(ip, envName, reason string, envID uint) { + if !m.Enabled { + return + } + line := fmt.Sprintf("failed enroll for env %s: %s", envName, reason) + if err := m.CreateNew("osctrl-tls", line, ip, LogTypeNode, SeverityWarning, envID); err != nil { + log.Err(err).Msg("error creating failed-enroll audit log") + } +} + // NewLogout - create new logout audit log entry func (m *AuditLogManager) NewLogout(username, ip string) { if !m.Enabled { @@ -224,6 +253,22 @@ func (m *AuditLogManager) EnvAction(username, action, ip string, envID uint) { } } +// Denied records a 403/forbidden access attempt at SeverityWarning so SoC +// dashboards can surface cross-tenant probes. logType pins the resource +// class (LogTypeEnvironment for env handlers, LogTypeNode for node +// handlers, etc.). envID is the env the resource lives in, or +// NoEnvironment when the deny happened before env resolution. The reason +// field is short free text — never echo back the offered credential. +func (m *AuditLogManager) Denied(username, path, ip, reason string, logType, envID uint) { + if !m.Enabled { + return + } + line := fmt.Sprintf("denied access for user %s to %s: %s", username, path, reason) + if err := m.CreateNew(username, line, ip, logType, SeverityWarning, envID); err != nil { + log.Err(err).Msg("error creating denied-access audit log") + } +} + // SettingsAction - create new settings action audit log entry func (m *AuditLogManager) SettingsAction(username, action, ip string) { if !m.Enabled { diff --git a/pkg/carves/utils.go b/pkg/carves/utils.go index d800ecf3..bd224a96 100644 --- a/pkg/carves/utils.go +++ b/pkg/carves/utils.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "fmt" + "regexp" "strings" "github.com/jmpsec/osctrl/pkg/utils" @@ -78,7 +79,35 @@ func GenCarveName() string { return "carve_" + utils.RandomForNames() } -// Helper to generate the carve query +// validCarvePath restricts the characters that can appear in a carve +// path. The carve string is concatenated into the osquery SQL that +// every targeted node executes; without this gate a CarveLevel +// operator could inject arbitrary osquery (e.g. `'; SELECT 1; --`) and +// pivot from "exfil this path" to "run any SELECT against your nodes". +// +// The character class covers realistic carve targets across the three +// platforms: absolute POSIX paths (Linux/macOS), Windows paths with +// backslashes and drive letters, and glob wildcards (* and ?). It +// explicitly excludes single quote, semicolon, and comment markers. +var validCarvePath = regexp.MustCompile(`^[/A-Za-z0-9._\-\\:*?]+$`) + +// ValidCarvePath reports whether s is a safe value to splice into +// GenCarveQuery. Callers MUST verify before calling GenCarveQuery — +// the result is interpolated directly into SQL. +func ValidCarvePath(s string) bool { + if s == "" { + return false + } + return validCarvePath.MatchString(s) +} + +// Helper to generate the carve query. +// +// `file` is interpolated into the SQL string verbatim. The caller MUST +// have validated it via ValidCarvePath beforehand — passing an +// unvalidated user-controlled value here lets the requesting operator +// run arbitrary osquery on every targeted host, which is well beyond +// the "carve a file" capability the endpoint advertises. func GenCarveQuery(file string, glob bool) string { if glob { return "SELECT * FROM carves WHERE carve=1 AND path LIKE '" + file + "';" diff --git a/pkg/carves/utils_test.go b/pkg/carves/utils_test.go new file mode 100644 index 00000000..03824410 --- /dev/null +++ b/pkg/carves/utils_test.go @@ -0,0 +1,51 @@ +package carves + +import ( + "strings" + "testing" +) + +// TestValidCarvePath locks the character allowlist that gates GenCarveQuery. +func TestValidCarvePath(t *testing.T) { + good := []string{ + "/etc/passwd", + "/var/log/auth.log", + "C:\\Windows\\System32\\drivers\\etc\\hosts", + "/Users/alice/Library/Application_Support/com.example/cfg", + "/var/log/*.log", + "/var/log/auth?.log", + } + for _, p := range good { + if !ValidCarvePath(p) { + t.Errorf("ValidCarvePath(%q): expected true", p) + } + } + bad := []string{ + "", + "'; SELECT 1; --", + "/var/log/a'b", + "/var/log/a;b", + "/var/log/a b", // space + "/var/log/a\"b", + "/var/log/a\nb", + } + for _, p := range bad { + if ValidCarvePath(p) { + t.Errorf("ValidCarvePath(%q): expected false", p) + } + } +} + +// TestGenCarveQueryShape sanity-checks the SQL shape for both glob and +// exact match. Real callers MUST validate file via ValidCarvePath first; +// this test exercises the happy path only. +func TestGenCarveQueryShape(t *testing.T) { + q1 := GenCarveQuery("/etc/passwd", false) + if !strings.Contains(q1, "path = '/etc/passwd'") { + t.Errorf("exact: got %q", q1) + } + q2 := GenCarveQuery("/var/log/*.log", true) + if !strings.Contains(q2, "path LIKE '/var/log/*.log'") { + t.Errorf("glob: got %q", q2) + } +} diff --git a/pkg/config/flags.go b/pkg/config/flags.go index eb48dede..96fbc053 100644 --- a/pkg/config/flags.go +++ b/pkg/config/flags.go @@ -194,8 +194,8 @@ func initServiceFlags(params *ServiceParameters) []cli.Flag { &cli.StringFlag{ Name: "auth", Aliases: []string{"A"}, - Value: AuthNone, - Usage: "Authentication mechanism for the service", + Value: AuthJWT, + Usage: "Authentication mechanism for the service (jwt|none — `none` requires OSCTRL_INSECURE_NO_AUTH=1)", Sources: cli.EnvVars("SERVICE_AUTH"), Destination: ¶ms.Service.Auth, }, @@ -216,11 +216,18 @@ func initServiceFlags(params *ServiceParameters) []cli.Flag { &cli.BoolFlag{ Name: "audit-log", Aliases: []string{"audit"}, - Value: false, - Usage: "Enable audit log for the service. Logs all sensitive actions", + Value: true, + Usage: "Enable audit log for the service. Logs sensitive actions (logins, env mutations, query/carve runs, etc.). Disable only for local dev — production deployments MUST keep this on so SoC tooling has a stream to alert on.", Sources: cli.EnvVars("AUDIT_LOG"), Destination: ¶ms.Service.AuditLog, }, + &cli.StringFlag{ + Name: "trusted-proxies", + Value: "", + Usage: "Comma-separated CIDR list whose X-Real-IP / X-Forwarded-For headers will be honored. Empty (default) ignores forwarding headers and uses RemoteAddr verbatim — prevents header-spoofed rate-limit bypass and audit-log poisoning.", + Sources: cli.EnvVars("SERVICE_TRUSTED_PROXIES"), + Destination: ¶ms.Service.TrustedProxies, + }, } } diff --git a/pkg/config/types.go b/pkg/config/types.go index 1c4295ca..0f64d5e5 100644 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -120,6 +120,11 @@ type YAMLConfigurationService struct { Host string `yaml:"host"` Auth string `yaml:"auth"` AuditLog bool `yaml:"auditLog"` + // TrustedProxies is a comma-separated list of CIDRs whose + // X-Real-IP / X-Forwarded-For headers utils.GetIP will honor. + // Default empty → forwarding headers are ignored and the + // connection's RemoteAddr is used. + TrustedProxies string `yaml:"trustedProxies"` } // YAMLConfigurationDB to hold all backend configuration values diff --git a/pkg/environments/env-cache.go b/pkg/environments/env-cache.go index 1f66b843..31226ef1 100644 --- a/pkg/environments/env-cache.go +++ b/pkg/environments/env-cache.go @@ -9,6 +9,19 @@ import ( const ( cacheName = "environments" + // envCacheTTL is the maximum time a TLSEnvironment can sit in the + // EnvCache before the next request refetches from the database. + // + // osctrl-tls holds this cache; osctrl-api mutates env rows in the + // same DB from a different process. There is no IPC channel between + // the two, so envCache invalidation is TTL-based — the TTL bounds + // the window during which enroll-secret rotations, env deletions, + // or config-PATCH changes can be served stale by osctrl-tls. + // + // Kept at the historical 2h cleanup interval; operators who need + // faster invalidation can rotate via `osctrl-tls` restart or tune + // this constant locally. + envCacheTTL = 2 * time.Hour ) // EnvCache provides cached access to TLS environments @@ -22,9 +35,8 @@ type EnvCache struct { // NewEnvCache creates a new environment cache func NewEnvCache(envs EnvManager) *EnvCache { - // Create a new cache with a 10-minute cleanup interval envCache := cache.NewMemoryCache( - cache.WithCleanupInterval[TLSEnvironment](2*time.Hour), + cache.WithCleanupInterval[TLSEnvironment](envCacheTTL), cache.WithName[TLSEnvironment](cacheName), ) @@ -47,24 +59,27 @@ func (ec *EnvCache) GetByUUID(ctx context.Context, uuid string) (TLSEnvironment, return TLSEnvironment{}, err } - ec.cache.Set(ctx, uuid, env, 2*time.Hour) + ec.cache.Set(ctx, uuid, env, envCacheTTL) return env, nil } -// InvalidateEnv removes a specific environment from the cache +// InvalidateEnv removes a specific environment from the cache. Callers +// that mutate env rows in the same process SHOULD invoke this so the +// next request refetches the row without waiting for the TTL. func (ec *EnvCache) InvalidateEnv(ctx context.Context, uuid string) { ec.cache.Delete(ctx, uuid) } -// InvalidateAll clears the entire cache +// InvalidateAll clears the entire cache. Used on bulk operations or +// after operator-driven secret rotations. func (ec *EnvCache) InvalidateAll(ctx context.Context) { ec.cache.Clear(ctx) } // UpdateEnvInCache updates an environment in the cache func (ec *EnvCache) UpdateEnvInCache(ctx context.Context, env TLSEnvironment) { - ec.cache.Set(ctx, env.UUID, env, 2*time.Hour) + ec.cache.Set(ctx, env.UUID, env, envCacheTTL) } // Close stops the cleanup goroutine and releases resources diff --git a/pkg/environments/environments.go b/pkg/environments/environments.go index a419e382..848cece5 100644 --- a/pkg/environments/environments.go +++ b/pkg/environments/environments.go @@ -214,13 +214,35 @@ func (environment *EnvManager) Create(env *TLSEnvironment) error { return nil } -// Exists checks if TLS Environment exists already +// Exists checks if TLS Environment exists already by name OR uuid (polymorphic). +// Prefer ExistsByUUID / ExistsByName when the caller knows which axis to check — +// the polymorphic variant can confuse a UUID-collision check with a name match +// and vice versa, which leaked information across axes in EnvActionsHandler. +// (Cluster-4 review item — see ExistsByUUID below.) func (environment *EnvManager) Exists(identifier string) bool { var results int64 environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", identifier, identifier).Count(&results) return (results > 0) } +// ExistsByUUID checks if a TLS Environment exists by UUID only. +// Use this when validating a client-supplied UUID for collision before +// creating a new environment, or for unambiguous delete-by-UUID semantics. +func (environment *EnvManager) ExistsByUUID(uuid string) bool { + var results int64 + environment.DB.Model(&TLSEnvironment{}).Where("uuid = ?", uuid).Count(&results) + return (results > 0) +} + +// ExistsByName checks if a TLS Environment exists by name only. +// (Companion to ExistsByUUID — provided for symmetry; callers preferring the +// polymorphic Exists() can keep using it.) +func (environment *EnvManager) ExistsByName(name string) bool { + var results int64 + environment.DB.Model(&TLSEnvironment{}).Where("name = ?", name).Count(&results) + return (results > 0) +} + // ExistsGet checks if TLS Environment exists already and returns it func (environment *EnvManager) ExistsGet(identifier string) (bool, TLSEnvironment) { e, err := environment.Get(identifier) diff --git a/pkg/ratelimit/ratelimit.go b/pkg/ratelimit/ratelimit.go new file mode 100644 index 00000000..85a4eb89 --- /dev/null +++ b/pkg/ratelimit/ratelimit.go @@ -0,0 +1,144 @@ +// Package ratelimit provides a small token-bucket rate-limit middleware +// used to protect anonymous attack surfaces (login, enroll) from +// brute-force / password-spray. +// +// The Limiter is keyed by a caller-supplied function (IP, IP+username, +// etc.) so the same primitive can fan out to per-endpoint policies. +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/jmpsec/osctrl/pkg/utils" + "golang.org/x/time/rate" +) + +// DefaultMaxBuckets is the cap on the per-key map size. Once exceeded, +// new keys all share a single overflow bucket, so an attacker churning +// arbitrary keys (X-Forwarded-For spoofing or a similar primitive in a +// future surface) cannot grow the limiter's memory footprint unbounded. +const DefaultMaxBuckets = 100_000 + +// Limiter is a sharded map of token buckets keyed by an arbitrary string. +// Buckets age out after `evictAfter` of inactivity so the map doesn't grow +// unbounded. Eviction is amortized — the full O(N) scan runs at most once +// per `evictAfter/2` so a single hot-path Allow doesn't pay the cost. +// When the map exceeds maxBuckets, new keys collapse onto a shared +// overflow bucket; the spray still gets rate-limited (just not per-key) +// and memory stays bounded. +type Limiter struct { + mu sync.Mutex + buckets map[string]*entry + overflow *rate.Limiter + maxBuckets int + rate rate.Limit + burst int + evictAfter time.Duration + lastEviction time.Time + evictInterval time.Duration +} + +type entry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// New returns a Limiter that allows up to `burst` events per key over `per`, +// with steady-state refill at `burst/per`. evictAfter is the inactivity +// window after which a key's bucket is forgotten — pick something larger +// than `per` so genuine retries don't reset their bucket. +// +// The bucket map is capped at DefaultMaxBuckets entries. Operators that +// need a different cap can construct via NewWithCap. +func New(burst int, per, evictAfter time.Duration) *Limiter { + return NewWithCap(burst, per, evictAfter, DefaultMaxBuckets) +} + +// NewWithCap is New with an explicit ceiling on the per-key map size. +func NewWithCap(burst int, per, evictAfter time.Duration, maxBuckets int) *Limiter { + interval := evictAfter / 2 + if interval <= 0 { + interval = time.Second + } + if maxBuckets <= 0 { + maxBuckets = DefaultMaxBuckets + } + r := rate.Every(per / time.Duration(burst)) + return &Limiter{ + buckets: make(map[string]*entry), + overflow: rate.NewLimiter(r, burst), + maxBuckets: maxBuckets, + rate: r, + burst: burst, + evictAfter: evictAfter, + evictInterval: interval, + } +} + +// Allow returns true if the supplied key can perform one event under the +// current bucket state. Side-effect: the bucket is created on first use +// and idle buckets are GC'd opportunistically (at most once per +// evictInterval to keep the hot path constant-time). When the map is +// already at maxBuckets and the key has no existing bucket, the call +// falls back to the shared overflow bucket so memory stays bounded. +func (l *Limiter) Allow(key string) bool { + now := time.Now() + l.mu.Lock() + defer l.mu.Unlock() + // Amortized eviction: walk the map only when the throttle says it's + // time. Each Allow is O(1) on the steady-state path. (Cluster-3 + // review item — keeps the lock-held duration bounded under load.) + if now.Sub(l.lastEviction) >= l.evictInterval { + for k, e := range l.buckets { + if now.Sub(e.lastSeen) > l.evictAfter { + delete(l.buckets, k) + } + } + l.lastEviction = now + } + if e, ok := l.buckets[key]; ok { + e.lastSeen = now + return e.limiter.Allow() + } + // New key. If the map is at the cap, route through the shared + // overflow bucket — spray attackers can saturate it, but legitimate + // keys that already have a bucket still get their own quota. + // + if len(l.buckets) >= l.maxBuckets { + return l.overflow.Allow() + } + e := &entry{limiter: rate.NewLimiter(l.rate, l.burst), lastSeen: now} + l.buckets[key] = e + return e.limiter.Allow() +} + +// HTTPMiddleware returns a middleware that rejects requests with 429 when +// `keyFn(r)` exceeds the limit. keyFn is responsible for choosing the +// dimension (e.g., utils.GetIP(r), or `utils.GetIP(r) + ":" + username`). +// +// onReject is invoked synchronously when a request is rejected — use it to +// emit an audit-log entry. May be nil. +func (l *Limiter) HTTPMiddleware(keyFn func(*http.Request) string, onReject func(*http.Request, string)) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := keyFn(r) + if !l.Allow(key) { + if onReject != nil { + onReject(r, key) + } + w.Header().Set("Retry-After", "60") + http.Error(w, "too many requests", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// KeyByIP is a convenience keyFn for IP-based rate limiting. Honors +// X-Real-IP / X-Forwarded-For via utils.GetIP. +func KeyByIP(r *http.Request) string { + return utils.GetIP(r) +} diff --git a/pkg/ratelimit/ratelimit_test.go b/pkg/ratelimit/ratelimit_test.go new file mode 100644 index 00000000..61dfc466 --- /dev/null +++ b/pkg/ratelimit/ratelimit_test.go @@ -0,0 +1,108 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// TestAllowBurst verifies a Limiter allows up to `burst` calls in a single +// window and then refuses the (burst+1)th. +func TestAllowBurst(t *testing.T) { + l := New(3, time.Second, time.Minute) + for i := 0; i < 3; i++ { + if !l.Allow("k") { + t.Fatalf("expected Allow #%d to return true", i+1) + } + } + if l.Allow("k") { + t.Fatal("expected the burst+1 request to be rejected") + } +} + +// TestAllowSeparateKeys verifies buckets don't bleed between keys. +func TestAllowSeparateKeys(t *testing.T) { + l := New(2, time.Second, time.Minute) + l.Allow("a") + l.Allow("a") + if l.Allow("a") { + t.Fatal("key a should be over budget") + } + if !l.Allow("b") { + t.Fatal("key b has its own budget") + } +} + +// TestHTTPMiddleware429s verifies the middleware returns 429 + Retry-After +// when the bucket is empty and calls onReject for telemetry. +func TestHTTPMiddleware429s(t *testing.T) { + l := New(1, time.Second, time.Minute) + rejected := 0 + mw := l.HTTPMiddleware( + func(r *http.Request) string { return "fixed" }, + func(r *http.Request, key string) { rejected++ }, + ) + allowed := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + first := httptest.NewRecorder() + allowed.ServeHTTP(first, httptest.NewRequest("POST", "/login", nil)) + if first.Code != http.StatusOK { + t.Fatalf("first request: got %d, want 200", first.Code) + } + + second := httptest.NewRecorder() + allowed.ServeHTTP(second, httptest.NewRequest("POST", "/login", nil)) + if second.Code != http.StatusTooManyRequests { + t.Fatalf("second request: got %d, want 429", second.Code) + } + if got := second.Header().Get("Retry-After"); got == "" { + t.Fatal("missing Retry-After header on 429") + } + if rejected != 1 { + t.Fatalf("onReject calls: got %d, want 1", rejected) + } +} + +// TestBucketCapOverflow — once `maxBuckets` is reached, additional +// distinct keys all route through the shared overflow bucket so map +// growth is bounded. Existing keys keep their per-key budget. +func TestBucketCapOverflow(t *testing.T) { + // burst=1, per=time.Hour — each per-key bucket allows exactly one + // request before refilling. + l := NewWithCap(1, time.Hour, time.Minute, 2) + + // Two keys → both get their own bucket and one Allow each. + if !l.Allow("k1") { + t.Fatal("k1 first Allow must succeed") + } + if !l.Allow("k2") { + t.Fatal("k2 first Allow must succeed") + } + if l.Allow("k1") { + t.Fatal("k1 second Allow must fail (per-key budget exhausted)") + } + + // k3 / k4 / k5 are NEW keys past the cap. They all share the + // overflow bucket (burst 1). The first one consumes the overflow + // burst; the rest must be denied. + got := 0 + for _, k := range []string{"k3", "k4", "k5", "k6"} { + if l.Allow(k) { + got++ + } + } + if got > 1 { + t.Fatalf("overflow burst must be 1, got %d successful Allows on capped keys", got) + } + + // Verify the map didn't grow past the cap. + l.mu.Lock() + size := len(l.buckets) + l.mu.Unlock() + if size > 2 { + t.Fatalf("bucket map exceeded cap: size=%d, cap=2", size) + } +} diff --git a/pkg/types/types.go b/pkg/types/types.go index c816f395..2536441a 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,5 +1,7 @@ package types +import "time" + // OsqueryTable to show tables to query type OsqueryTable struct { Name string `json:"name"` @@ -85,6 +87,7 @@ type ApiLoginRequest struct { // ApiErrorResponse to be returned to API requests with the error message type ApiErrorResponse struct { Error string `json:"error"` + Code string `json:"code,omitempty"` } // ApiQueriesResponse to be returned to API requests for queries @@ -104,7 +107,8 @@ type ApiDataResponse struct { // ApiLoginResponse to be returned to API login requests with the generated token type ApiLoginResponse struct { - Token string `json:"token"` + Token string `json:"token"` + CSRFToken string `json:"csrf_token,omitempty"` } // ApiActionsRequest to receive action requests @@ -155,3 +159,56 @@ type ApiUserRequest struct { API bool `json:"api"` Environments []string `json:"environments"` } + +// TLSEnvironmentView is the low-privilege projection of an environment. +// UserLevel operators (env scope) need basic env metadata so the SPA can +// render its env switcher / dashboard / table chrome — but they MUST NOT +// receive the enroll secret, the certificate, or one-liner URLs that +// embed the secret. The full storage struct is admin-only via +// EnvironmentAdminHandler. +type TLSEnvironmentView struct { + ID uint `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UUID string `json:"uuid"` + Name string `json:"name"` + Hostname string `json:"hostname"` + Type string `json:"type"` + Icon string `json:"icon"` + DebugHTTP bool `json:"debug_http"` + ConfigTLS bool `json:"config_tls"` + ConfigInterval int `json:"config_interval"` + LoggingTLS bool `json:"logging_tls"` + LogInterval int `json:"log_interval"` + QueryTLS bool `json:"query_tls"` + QueryInterval int `json:"query_interval"` + CarvesTLS bool `json:"carves_tls"` + AcceptEnrolls bool `json:"accept_enrolls"` + EnrollExpire time.Time `json:"enroll_expire"` + RemoveExpire time.Time `json:"remove_expire"` +} + +// AdminUserView is the PII-minimized projection of an AdminUser for +// the GET /api/v1/users and GET /api/v1/users/{username} endpoints. +// Drops LastIPAddress / LastUserAgent / LastAccess / LastTokenUse: a +// super-admin reading another super-admin's record gets enough to +// manage them (username, email, fullname, admin/service flags, env +// scope) but not the network/timing metadata that helps an attacker +// who later compromises one super-admin profile target the others. +// +// Users querying THEIR OWN record see the metadata they need via the +// pre-existing UserMeResponse from /api/v1/users/me — this view is +// strictly for the cross-user "list / inspect another admin" paths. +type AdminUserView struct { + ID uint `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Username string `json:"username"` + Email string `json:"email"` + Fullname string `json:"fullname"` + Admin bool `json:"admin"` + Service bool `json:"service"` + UUID string `json:"uuid"` + TokenExpire time.Time `json:"token_expire"` + EnvironmentID uint `json:"environment_id"` +} diff --git a/pkg/users/permissions_test.go b/pkg/users/permissions_test.go index f91370bc..f4507caf 100644 --- a/pkg/users/permissions_test.go +++ b/pkg/users/permissions_test.go @@ -16,7 +16,7 @@ import ( func setupTestManagerForPermissions(t *testing.T) (*UserManager, sqlmock.Sqlmock) { conf := config.YAMLConfigurationJWT{ - JWTSecret: "test", + JWTSecret: "test-secret-must-be-at-least-32-bytes-long", HoursToExpire: 1, } mockDB, mock, err := sqlmock.New() diff --git a/pkg/users/users.go b/pkg/users/users.go index 5bc08716..8bd18989 100644 --- a/pkg/users/users.go +++ b/pkg/users/users.go @@ -54,12 +54,21 @@ type UserManager struct { JWTConfig *config.YAMLConfigurationJWT } +// MinJWTSecretBytes is the minimum acceptable length of the HMAC JWT secret +// (RFC 7518 §3.2 recommends a key at least as wide as the hash output for +// HS256 ⇒ 32 bytes). Generate one with: openssl rand -base64 48 +const MinJWTSecretBytes = 32 + // CreateUserManager to initialize the users struct and tables func CreateUserManager(backend *gorm.DB, jwtconfig *config.YAMLConfigurationJWT) *UserManager { - // Check if JWT is not empty + // JWT secret must be present and long enough for HS256. if jwtconfig.JWTSecret == "" { log.Fatal().Msgf("JWT Secret can not be empty") } + if len(jwtconfig.JWTSecret) < MinJWTSecretBytes { + log.Fatal().Msgf("JWT Secret too short: have %d bytes, need >= %d. Generate one with: openssl rand -base64 48", + len(jwtconfig.JWTSecret), MinJWTSecretBytes) + } u := &UserManager{DB: backend, JWTConfig: jwtconfig} // table admin_users if err := backend.AutoMigrate(&AdminUser{}); err != nil { @@ -72,10 +81,14 @@ func CreateUserManager(backend *gorm.DB, jwtconfig *config.YAMLConfigurationJWT) return u } +// BcryptCost is the bcrypt work factor for password hashing. 12 is the +// 2026 commodity-CPU recommendation; bcrypt.DefaultCost is 10. +const BcryptCost = 12 + // HashTextWithSalt to hash text before store it func (m *UserManager) HashTextWithSalt(text string) (string, error) { saltedBytes := []byte(text) - hashedBytes, err := bcrypt.GenerateFromPassword(saltedBytes, bcrypt.DefaultCost) + hashedBytes, err := bcrypt.GenerateFromPassword(saltedBytes, BcryptCost) if err != nil { return "", err } @@ -88,7 +101,12 @@ func (m *UserManager) HashPasswordWithSalt(password string) (string, error) { return m.HashTextWithSalt(password) } -// CheckLoginCredentials to check provided login credentials by matching hashes +// CheckLoginCredentials matches password hashes and, on a successful +// match, opportunistically re-hashes the password at the current +// BcryptCost when the stored hash is below it. Users created under an +// older cost migrate transparently on their next login. The rehash +// failure is non-fatal — login succeeds even if the rehash write +// fails (next login retries). func (m *UserManager) CheckLoginCredentials(username, password string) (bool, AdminUser) { // Check if we should include service users user, err := m.Get(username) @@ -98,10 +116,21 @@ func (m *UserManager) CheckLoginCredentials(username, password string) (bool, Ad // Check for hash matching p := []byte(password) existing := []byte(user.PassHash) - err = bcrypt.CompareHashAndPassword(existing, p) - if err != nil { + if err := bcrypt.CompareHashAndPassword(existing, p); err != nil { return false, AdminUser{} } + // Successful login — rehash if the stored cost is below current. + if cost, cerr := bcrypt.Cost(existing); cerr == nil && cost < BcryptCost { + if newHash, herr := m.HashPasswordWithSalt(password); herr == nil { + if uerr := m.DB.Model(&user).Update("pass_hash", newHash).Error; uerr != nil { + log.Err(uerr).Msgf("rehash-on-login: failed to persist new pass_hash for %s", username) + } else { + user.PassHash = newHash + } + } else { + log.Err(herr).Msgf("rehash-on-login: bcrypt cost upgrade failed for %s", username) + } + } return true, user } @@ -130,10 +159,16 @@ func (m *UserManager) CreateToken(username, issuer string, expHours int) (string return tokenString, expirationTime, nil } -// CheckToken to verify if a token used is valid +// CheckToken to verify if a token used is valid. +// Pins the signing algorithm to HMAC so an attacker cannot swap to `alg:none` +// or RS256-with-public-key (RS-vs-HS confusion) — defense-in-depth on top of +// the underlying library's own mitigations. func (m *UserManager) CheckToken(jwtSecret, tokenStr string) (TokenClaims, bool) { claims := &TokenClaims{} tkn, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected jwt signing method: %v", token.Header["alg"]) + } return []byte(jwtSecret), nil }) if err != nil { @@ -234,6 +269,18 @@ func (m *UserManager) IsAdmin(username string) bool { return (results > 0) } +// CountAdmins returns the number of active admin (Admin=true) users. +// Used by the permissions API to refuse demoting the last super-admin +// (which would lock the system out — no remaining super-admin = no +// one can promote anyone else). +func (m *UserManager) CountAdmins() (int64, error) { + var results int64 + if err := m.DB.Model(&AdminUser{}).Where("admin = ?", true).Count(&results).Error; err != nil { + return 0, fmt.Errorf("count admins: %w", err) + } + return results, nil +} + // ChangeAdmin to modify the admin setting for a user func (m *UserManager) ChangeAdmin(username string, admin bool) error { user, err := m.Get(username) @@ -327,17 +374,42 @@ func (m *UserManager) UpdateToken(username, token string, exp time.Time) error { return fmt.Errorf("error getting user %w", err) } if token != user.APIToken { - if err := m.DB.Model(&user).Updates( - AdminUser{ - APIToken: token, - TokenExpire: exp, - }).Error; err != nil { + // Rotation also clears CSRFToken so the SPA's old non-HttpOnly + // CSRF cookie value stops matching the server-side binding — + // stops a stale CSRFToken from outliving the JWT it was minted + // alongside. The SPA must re-login (which writes a fresh + // CSRFToken via UpdateMetadata) before mutations work again. + // + if err := m.DB.Model(&user).Updates(map[string]interface{}{ + "api_token": token, + "token_expire": exp, + "csrf_token": "", + }).Error; err != nil { return fmt.Errorf("update %w", err) } } return nil } +// ClearToken empties the user's APIToken and CSRFToken so any existing +// JWT + CSRF cookie pair for them stops validating. Used by DELETE +// /api/v1/users/{username}/token. We use a map-update so the empty +// strings actually land (GORM's struct-Updates skips zero-value fields). +func (m *UserManager) ClearToken(username string) error { + user, err := m.Get(username) + if err != nil { + return fmt.Errorf("error getting user %w", err) + } + if err := m.DB.Model(&user).Updates(map[string]interface{}{ + "api_token": "", + "token_expire": time.Time{}, + "csrf_token": "", + }).Error; err != nil { + return fmt.Errorf("update %w", err) + } + return nil +} + // ChangeEmail for user by username func (m *UserManager) ChangeEmail(username, email string) error { user, err := m.Get(username) diff --git a/pkg/users/users_test.go b/pkg/users/users_test.go index 977a4a6b..771320ff 100644 --- a/pkg/users/users_test.go +++ b/pkg/users/users_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/DATA-DOG/go-sqlmock" + "github.com/golang-jwt/jwt/v4" "github.com/jmpsec/osctrl/pkg/config" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -16,7 +17,7 @@ import ( func setupTestManager(t *testing.T) (*UserManager, sqlmock.Sqlmock) { conf := config.YAMLConfigurationJWT{ - JWTSecret: "test", + JWTSecret: "test-secret-must-be-at-least-32-bytes-long", HoursToExpire: 1, } mockDB, mock, err := sqlmock.New() @@ -72,14 +73,14 @@ func TestHashTextWithSalt(t *testing.T) { manager, _ := setupTestManager(t) hashed, err := manager.HashTextWithSalt("testText") assert.NoError(t, err) - assert.Equal(t, hashed[0:7], "$2a$10$") + assert.Equal(t, hashed[0:7], "$2a$12$") } func TestHashPasswordWithSalt(t *testing.T) { manager, _ := setupTestManager(t) hashed, err := manager.HashPasswordWithSalt("testPassword") assert.NoError(t, err) - assert.Equal(t, hashed[0:7], "$2a$10$") + assert.Equal(t, hashed[0:7], "$2a$12$") } func TestCheckLoginCredentials(t *testing.T) { @@ -105,7 +106,7 @@ func TestCheckLoginCredentials(t *testing.T) { func TestCreateCheckToken(t *testing.T) { manager, _ := setupTestManager(t) conf := config.YAMLConfigurationJWT{ - JWTSecret: "test", + JWTSecret: "test-secret-must-be-at-least-32-bytes-long", } token, tt, err := manager.CreateToken("testUsername", "issuer", 0) assert.NoError(t, err) @@ -117,6 +118,20 @@ func TestCreateCheckToken(t *testing.T) { assert.Equal(t, "testUsername", claims.Username) } +// TestCheckTokenRejectsNoneAlg locks in the key-func's alg-pinning behaviour: +// even if a forged token bypasses the library's own none-mitigation, our +// explicit `*jwt.SigningMethodHMAC` type-assertion refuses it. +func TestCheckTokenRejectsNoneAlg(t *testing.T) { + manager, _ := setupTestManager(t) + // Hand-build a token signed with alg:none. golang-jwt requires + // jwt.UnsafeAllowNoneSignatureType as the key for SignedString to succeed. + tok := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{"username": "attacker"}) + signed, err := tok.SignedString(jwt.UnsafeAllowNoneSignatureType) + assert.NoError(t, err) + _, valid := manager.CheckToken("test-secret-must-be-at-least-32-bytes-long", signed) + assert.False(t, valid, "alg:none tokens must be rejected by the key-func") +} + func TestGetUser(t *testing.T) { manager, mock := setupTestManager(t) mock.ExpectQuery( @@ -387,9 +402,12 @@ func TestUpdateToken(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectBegin() + // UpdateToken now also clears csrf_token alongside api_token / + // token_expire so a stale CSRF cookie can't outlive its session. + // mock.ExpectExec( - regexp.QuoteMeta(`UPDATE "admin_users" SET "updated_at"=$1,"api_token"=$2,"token_expire"=$3 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $4`)). - WithArgs(sqlmock.AnyArg(), "testToken", tt, 1). + regexp.QuoteMeta(`UPDATE "admin_users" SET "api_token"=$1,"csrf_token"=$2,"token_expire"=$3,"updated_at"=$4 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $5`)). + WithArgs("testToken", "", tt, sqlmock.AnyArg(), 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() @@ -430,3 +448,45 @@ func TestGetAllUsers(t *testing.T) { assert.Equal(t, 1, len(users)) } + +// TestUpdateTokenClearsCSRF locks the contract that rotating APIToken +// also clears CSRFToken so a stale CSRF cookie can't outlive its +// session. +func TestUpdateTokenClearsCSRF(t *testing.T) { + manager, mock := setupTestManager(t) + tt := time.Now() + mock.ExpectQuery( + regexp.QuoteMeta(`SELECT * FROM "admin_users" WHERE username = $1 AND "admin_users"."deleted_at" IS NULL ORDER BY "admin_users"."id" LIMIT $2`)). + WithArgs("alice", 1). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + mock.ExpectBegin() + mock.ExpectExec( + regexp.QuoteMeta(`UPDATE "admin_users" SET "api_token"=$1,"csrf_token"=$2,"token_expire"=$3,"updated_at"=$4 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $5`)). + WithArgs("freshtoken", "", tt, sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := manager.UpdateToken("alice", "freshtoken", tt) + assert.NoError(t, err) +} + +// TestClearTokenAlsoClearsCSRF locks the contract that DELETE +// /users/{u}/token wipes both api_token and csrf_token. +func TestClearTokenAlsoClearsCSRF(t *testing.T) { + manager, mock := setupTestManager(t) + mock.ExpectQuery( + regexp.QuoteMeta(`SELECT * FROM "admin_users" WHERE username = $1 AND "admin_users"."deleted_at" IS NULL ORDER BY "admin_users"."id" LIMIT $2`)). + WithArgs("bob", 1). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + mock.ExpectBegin() + mock.ExpectExec( + regexp.QuoteMeta(`UPDATE "admin_users" SET "api_token"=$1,"csrf_token"=$2,"token_expire"=$3,"updated_at"=$4 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $5`)). + WithArgs("", "", sqlmock.AnyArg(), sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := manager.ClearToken("bob") + assert.NoError(t, err) +} diff --git a/pkg/utils/http-utils.go b/pkg/utils/http-utils.go index 41cf8136..1636b5b2 100644 --- a/pkg/utils/http-utils.go +++ b/pkg/utils/http-utils.go @@ -6,10 +6,13 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" "strconv" + "strings" + "sync" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -83,6 +86,14 @@ const Authorization string = "Authorization" // OsctrlUserAgent for customized User-Agent const OsctrlUserAgent string = "osctrl-http-client/1.1" +// AcceptsJSON reports whether the request's Accept header signals JSON. +// Used by the auth middleware to choose between 401 JSON (for SPA/XHR +// clients) and 302 redirect (for browser navigation). +func AcceptsJSON(r *http.Request) bool { + accept := r.Header.Get("Accept") + return strings.Contains(strings.ToLower(accept), "application/json") +} + // SendRequest - Helper function to send HTTP requests func SendRequest(reqType, reqURL string, params io.Reader, headers map[string]string) (int, []byte, error) { u, err := url.Parse(reqURL) @@ -146,17 +157,122 @@ func DebugHTTPDump(l *zerolog.Logger, r *http.Request, showBody bool) { l.Log().Msg(DebugHTTP(r, showBody)) } -// GetIP - Helper to get the IP address from a HTTP request +// trustedProxies is the global set of CIDRs whose X-Real-IP / +// X-Forwarded-For headers GetIP is allowed to honor. When empty (the +// safe default), GetIP returns the connection's RemoteAddr IP verbatim +// and ignores any forwarding headers — preventing an anonymous internet +// attacker from rotating headers to defeat rate-limits or poison the +// audit log. Operators wire trusted proxies at startup via +// SetTrustedProxies; once set, GetIP only consults forwarding headers +// when the connecting peer falls inside one of the configured CIDRs. +var ( + trustedProxiesMu sync.RWMutex + trustedProxies []*net.IPNet +) + +// SetTrustedProxies configures the CIDR allowlist for forwarding-header +// trust. Pass an empty slice (or call with no args) to revert to the +// safe-by-default "ignore forwarding headers" posture. Each CIDR string +// must parse via net.ParseCIDR; invalid entries are logged and skipped. +func SetTrustedProxies(cidrs []string) { + parsed := make([]*net.IPNet, 0, len(cidrs)) + for _, c := range cidrs { + c = strings.TrimSpace(c) + if c == "" { + continue + } + _, n, err := net.ParseCIDR(c) + if err != nil { + log.Warn().Str("cidr", c).Err(err).Msg("trusted-proxies: invalid CIDR, skipping") + continue + } + parsed = append(parsed, n) + } + trustedProxiesMu.Lock() + trustedProxies = parsed + trustedProxiesMu.Unlock() +} + +// isFromTrustedProxy reports whether the connecting peer (host portion +// of r.RemoteAddr) sits inside any configured trusted-proxy CIDR. +func isFromTrustedProxy(r *http.Request) bool { + trustedProxiesMu.RLock() + tps := trustedProxies + trustedProxiesMu.RUnlock() + if len(tps) == 0 { + return false + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + for _, n := range tps { + if n.Contains(ip) { + return true + } + } + return false +} + +// remoteIP returns the connecting peer's IP (no port). Falls back to +// RemoteAddr-as-is when SplitHostPort fails (rare; some net/http test +// machinery omits the port). +func remoteIP(r *http.Request) string { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +// GetIP returns the client IP for r. When trusted-proxies are configured +// AND r.RemoteAddr's IP is inside one of them, the right-most untrusted +// hop from X-Forwarded-For (or X-Real-IP) is used (per RFC 7239 §5.2 the +// right-most-untrusted is the IP the trusted edge actually saw connect). +// Otherwise the forwarding headers are ignored and the connection's +// RemoteAddr IP is returned. func GetIP(r *http.Request) string { - realIP := r.Header.Get(XRealIP) - if realIP != "" { - return realIP + if !isFromTrustedProxy(r) { + // Default safe path: never trust forwarding headers. + return remoteIP(r) + } + // Trusted-proxy path. Prefer X-Forwarded-For (a comma-list of hops: + // `client, proxy1, proxy2`). Walk right-to-left and return the + // first IP that's NOT itself inside a trusted-proxy CIDR. + if xff := r.Header.Get(XForwardedFor); xff != "" { + hops := strings.Split(xff, ",") + trustedProxiesMu.RLock() + tps := trustedProxies + trustedProxiesMu.RUnlock() + for i := len(hops) - 1; i >= 0; i-- { + hop := strings.TrimSpace(hops[i]) + ip := net.ParseIP(hop) + if ip == nil { + continue + } + isProxy := false + for _, n := range tps { + if n.Contains(ip) { + isProxy = true + break + } + } + if !isProxy { + return hop + } + } } - forwarded := r.Header.Get(XForwardedFor) - if forwarded != "" { - return forwarded + // Fall back to X-Real-IP (set by single-hop edges like nginx with + // `proxy_set_header X-Real-IP $remote_addr;`). + if rip := strings.TrimSpace(r.Header.Get(XRealIP)); rip != "" { + return rip } - return r.RemoteAddr + // Last resort: the trusted proxy's own address. + return remoteIP(r) } // HTTPResponse - Helper to send HTTP response diff --git a/pkg/utils/http-utils_test.go b/pkg/utils/http-utils_test.go index e7013c0e..e844077c 100644 --- a/pkg/utils/http-utils_test.go +++ b/pkg/utils/http-utils_test.go @@ -85,21 +85,31 @@ func TestSendRequest(t *testing.T) { } func TestGetIP(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + // All three sub-tests run with a trusted-proxy configuration that + // covers the test RemoteAddr (127.0.0.0/8 for httptest defaults + // and the test addresses below). Without trust configured, GetIP + // ignores forwarding headers — that contract is asserted in + // TestGetIPIgnoresHeadersByDefault. + SetTrustedProxies([]string{"127.0.0.0/8"}) t.Run("get ip X-Real-IP header", func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://whatever/server/path", nil) + req.RemoteAddr = "127.0.0.1:1234" // inside trusted CIDR req.Header.Set(XRealIP, "1.2.3.4") ip := GetIP(req) assert.Equal(t, "1.2.3.4", ip) }) t.Run("get ip X-Forwarder-For header", func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://whatever/server/path", nil) + req.RemoteAddr = "127.0.0.1:1234" req.Header.Set(XForwardedFor, "1.2.3.4") ip := GetIP(req) assert.Equal(t, "1.2.3.4", ip) }) t.Run("get ip RemoteAddr", func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://whatever/server/path", nil) - req.Header.Set(XForwardedFor, "") + // No RemoteAddr set and no headers — GetIP falls back to the + // empty value the request was built with. ip := GetIP(req) assert.Equal(t, "", ip) }) @@ -132,3 +142,70 @@ func TestHTTPDownload(t *testing.T) { assert.Equal(t, "123", rr.Header().Get(ContentLength)) }) } + +// TestGetIPIgnoresHeadersByDefault — out-of-the-box GetIP MUST NOT +// consult X-Real-IP / X-Forwarded-For. +func TestGetIPIgnoresHeadersByDefault(t *testing.T) { + SetTrustedProxies(nil) // reset + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "203.0.113.5:12345" + req.Header.Set("X-Real-IP", "99.99.99.99") + req.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("default GetIP: got %q, want %q (forwarding headers must be ignored)", got, "203.0.113.5") + } +} + +// TestGetIPHonorsTrustedProxy — when the connecting peer is inside a +// trusted-proxy CIDR, the right-most untrusted hop from X-Forwarded-For +// becomes the result. +func TestGetIPHonorsTrustedProxy(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"10.0.0.0/8"}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.5:12345" // trusted edge + // `client, edge1, edge2` — edge1/edge2 are inside the trusted CIDR, + // so the right-most-untrusted is "203.0.113.5". + req.Header.Set("X-Forwarded-For", "203.0.113.5, 10.0.0.1, 10.0.0.5") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("trusted XFF: got %q, want %q", got, "203.0.113.5") + } +} + +// TestGetIPUntrustedPeerIgnoresHeaders — even with trusted proxies set, +// a request coming from OUTSIDE the trusted CIDRs must ignore headers. +func TestGetIPUntrustedPeerIgnoresHeaders(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"10.0.0.0/8"}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "203.0.113.5:12345" // NOT in trusted CIDR + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("untrusted peer with header: got %q, want %q", got, "203.0.113.5") + } +} + +// TestGetIPTrustedProxyIPv6 — verify IPv6 trusted-proxy match. +func TestGetIPTrustedProxyIPv6(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"fd00::/8"}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "[fd00::1]:443" + req.Header.Set("X-Forwarded-For", "2001:db8::1") + if got := GetIP(req); got != "2001:db8::1" { + t.Errorf("trusted IPv6 XFF: got %q, want %q", got, "2001:db8::1") + } +} + +// TestSetTrustedProxiesIgnoresInvalid — bad CIDRs are dropped silently +// rather than panicking; the remaining good ones still apply. +func TestSetTrustedProxiesIgnoresInvalid(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"not-a-cidr", "10.0.0.0/8", "", " "}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:443" + req.Header.Set("X-Real-IP", "203.0.113.5") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("partial CIDR set: got %q, want %q", got, "203.0.113.5") + } +}