From 2610744569445d2f2f99cf1ab50b63ea5435ebf2 Mon Sep 17 00:00:00 2001 From: alvarofraguas Date: Thu, 14 May 2026 19:27:33 +0200 Subject: [PATCH 1/2] =?UTF-8?q?osctrl-api:=20security=20hardening=20?= =?UTF-8?q?=E2=80=94=20auth=20bedrock,=20env=20secret=20containment,=20sha?= =?UTF-8?q?red=20rate-limit=20+=20audit-log=20infra?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Server-side hardening for osctrl-api, plus shared infrastructure (rate-limit package, audit-log helpers, trusted-proxies plumbing) that osctrl-tls also consumes — its consumer-side changes ship in a companion PR so the TLS-facing surface can be tested in isolation. == Auth bedrock == cmd/api: - --auth=jwt is now the default. Refuse to start with --auth=none unless OSCTRL_INSECURE_NO_AUTH=1 is set. When opted in, a 60s warning ticker keeps the deployment from drifting into 'auth-off forever'. - HttpOnly + Secure cookie session for SPA-style clients (osctrl_token). CLI clients with Authorization: Bearer continue to work unchanged. - Double-submit CSRF (osctrl_csrf cookie + X-CSRF-Token header) for mutating cookie-authenticated requests. CLI Bearer flows exempt. - JWT signing-algorithm pin (HMAC only) to defeat alg-confusion attacks (alg:none / RS256-with-HS256-verify). - JWT secret minimum 32 bytes (HS256 needs HMAC key ≥ hash output). Startup fails fast with the openssl one-liner if too short. - Strict 'forwarded headers' trust via --trusted-proxies. Empty default means utils.GetIP ignores X-Forwarded-For / X-Real-IP — an internet attacker can't spoof IPs to defeat rate-limits or poison audit logs. == Env secret containment + cross-env defense == pkg/types: new TLSEnvironmentView — the low-privilege env projection. Omits Secret, EnrollSecretPath, RemoveSecretPath, Certificate, Flags, and every other field that materially contributes to enrolling a node. 
cmd/api/handlers/environments.go: - EnvironmentHandler now branches on access level: AdminLevel (or super-admin) gets the full storage struct; UserLevel gets the low-priv view. - EnvEnrollHandler / EnvRemoveHandler raised from UserLevel to AdminLevel — both embed the env's enroll/remove secret. - Both handlers log only the target name, not returnData. - EnvActionsHandler 'create' branch validates caller-supplied UUID via utils.CheckUUID (rejects malformed) and ExistsByUUID (rejects collision). 'delete' branch gets the same validation for symmetry. cmd/api/handlers/queries.go: QueryResultsHandler now precheck-validates the named query belongs to env.ID via h.Queries.Exists(name, env.ID) and returns 404 otherwise. logging.GetQueryResults filtered on 'name' only, so without this gate a user with QueryLevel on env A could pull results from env B by passing B's query name in A's URL. pkg/environments/environments.go: tighten EnvUUIDFilter regex and add axis-pure Exists/ExistsByUUID helpers so handler checks can match the router's expectations exactly. == Shared rate-limit + audit-log infrastructure == pkg/ratelimit (new): per-key token-bucket rate limiter with idle eviction. Used by osctrl-api for /login here, and by osctrl-tls for /enroll in the companion PR. Tunable burst, window, and key function (KeyByIP today; KeyByIPAndEnv available). pkg/auditlog/audit.go: FailedLogin + FailedEnroll helpers — a clean stream of authn/enroll failures for SOC tooling to alert on brute-force, password-spray, and enroll abuse. pkg/utils/http-utils.go: SetTrustedProxies + an updated GetIP that honors the trusted-proxies set. Empty (default) ignores X-Forwarded-For / X-Real-IP entirely. == SQL hardening + carve path safety == pkg/carves/utils.go: new ValidCarvePath regexp gate. Without this gate a CarveLevel operator could pass \`'; SELECT 1; --\` and pivot 'carve a file' into 'run any SELECT against your fleet' via GenCarveQuery's string concat. 
cmd/api/handlers/carves.go (CarvesRunHandler): path validated before the SQL splice. Rejected paths return 400. == Authz + audit-log hardening == pkg/users: - bcrypt cost raised from default (10) to 12. CheckLoginCredentials opportunistically re-hashes existing users at next login (no password reset needed). Rehash failure is non-fatal. - New ClearToken empties APIToken AND CSRFToken so any existing JWT + CSRF cookie pair stops validating. Used by future DELETE /api/v1/users/{username}/token in a follow-up PR. cmd/api/handlers/{users,settings,environments}.go: authz tightenings around permission writes, settings PATCH, and env-action service-name validation. pkg/environments/env-cache.go: keep the 2h cleanup interval; introduce an envCacheTTL constant so the value is self-documenting and tunable locally without changing runtime defaults. == Defaults + ops == deploy/config/{api,admin}.yml: flip --audit-log default to true so audit log writes are on by default. Operators can disable with --audit-log=false. Verified: go build ./... clean, go vet ./... clean, go test ./pkg/... ./cmd/api/... ./cmd/tls/... all green. 
--- cmd/api/auth.go | 122 +++++++++++++++++-- cmd/api/auth_test.go | 88 ++++++++++++++ cmd/api/handlers/carves.go | 9 ++ cmd/api/handlers/environments.go | 168 ++++++++++++++++++++------ cmd/api/handlers/environments_test.go | 91 ++++++++++++++ cmd/api/handlers/login.go | 98 +++++++++++++-- cmd/api/handlers/queries.go | 8 ++ cmd/api/handlers/settings.go | 13 +- cmd/api/handlers/users.go | 41 +++++-- cmd/api/main.go | 57 ++++++++- deploy/config/admin.yml | 2 +- deploy/config/api.yml | 16 ++- go.mod | 1 + go.sum | 2 + pkg/auditlog/audit.go | 45 +++++++ pkg/carves/utils.go | 31 ++++- pkg/carves/utils_test.go | 51 ++++++++ pkg/config/flags.go | 15 ++- pkg/config/types.go | 5 + pkg/environments/env-cache.go | 27 ++++- pkg/environments/environments.go | 24 +++- pkg/ratelimit/ratelimit.go | 144 ++++++++++++++++++++++ pkg/ratelimit/ratelimit_test.go | 108 +++++++++++++++++ pkg/types/types.go | 59 ++++++++- pkg/users/permissions_test.go | 2 +- pkg/users/users.go | 94 ++++++++++++-- pkg/users/users_test.go | 72 ++++++++++- pkg/utils/http-utils.go | 132 ++++++++++++++++++-- pkg/utils/http-utils_test.go | 79 +++++++++++- 29 files changed, 1487 insertions(+), 117 deletions(-) create mode 100644 cmd/api/auth_test.go create mode 100644 cmd/api/handlers/environments_test.go create mode 100644 pkg/carves/utils_test.go create mode 100644 pkg/ratelimit/ratelimit.go create mode 100644 pkg/ratelimit/ratelimit_test.go diff --git a/cmd/api/auth.go b/cmd/api/auth.go index 4bee5551..3c357931 100644 --- a/cmd/api/auth.go +++ b/cmd/api/auth.go @@ -2,11 +2,13 @@ package main import ( "context" + "crypto/subtle" "net/http" "strings" "github.com/jmpsec/osctrl/cmd/api/handlers" "github.com/jmpsec/osctrl/pkg/config" + "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/utils" "github.com/rs/zerolog/log" ) @@ -16,14 +18,79 @@ const ( contextAPI string = "osctrl-api-context" ) -// Helper to extract token from header +// Cookie + header names — kept in sync with 
cmd/api/handlers/login.go. +const ( + cookieNameToken = "osctrl_token" + cookieNameCSRF = "osctrl_csrf" + headerNameCSRF = "X-CSRF-Token" +) + +// Helper to extract token from the Authorization header first (CLI clients), +// falling back to the SPA's HttpOnly osctrl_token cookie. func extractHeaderToken(r *http.Request) string { - reqToken := r.Header.Get("Authorization") - splitToken := strings.Split(reqToken, "Bearer") - if len(splitToken) != 2 { - return "" + if v := r.Header.Get("Authorization"); v != "" { + splitToken := strings.Split(v, "Bearer") + if len(splitToken) == 2 { + if t := strings.TrimSpace(splitToken[1]); t != "" { + return t + } + } + } + if c, err := r.Cookie(cookieNameToken); err == nil { + return strings.TrimSpace(c.Value) + } + return "" +} + +// mutatingMethods is the set of HTTP verbs that must carry a valid CSRF token. +// GET/HEAD/OPTIONS are read-only and exempt. +var mutatingMethods = map[string]bool{ + http.MethodPost: true, + http.MethodPut: true, + http.MethodPatch: true, + http.MethodDelete: true, +} + +// checkCSRF enforces the double-submit CSRF pattern on mutating requests. +// The SPA reads the non-HttpOnly osctrl_csrf cookie and echoes it via the +// X-CSRF-Token header on every mutation; we constant-time-compare: +// 1. header == cookie value (classic double-submit), AND +// 2. cookie value == AdminUser.CSRFToken (defeats a cookie-tossing +// attacker who can set both header and cookie without DB write access). +// +// CLI clients that authenticate purely via Authorization: Bearer (no cookie) +// are exempt — there is no browser to ride a cross-site request from. +// +// Note: AdminUser.CSRFToken rotates on every successful /login (see +// LoginHandler ↦ Users.UpdateMetadata). Concurrent logins of the same user +// race; the loser keeps a cookie that no longer matches the stored value +// and gets 403 on the next mutation. 
APIToken refresh / clear also clear +// CSRFToken (see pkg/users.UpdateToken / ClearToken) so a stale CSRF +// cookie cannot outlive its session. +func checkCSRF(r *http.Request, username string) bool { + // r.Cookie returns ErrNoCookie only when the cookie name is absent; + // an empty-value cookie returns (cookie, nil). Treating the empty case + // as "Bearer client" would bypass CSRF — instead, the call to + // extractHeaderToken upstream rejects empty-value cookies before we + // reach this function (the trimmed value falls through to "" return). + if _, err := r.Cookie(cookieNameToken); err != nil { + // No session cookie ⇒ Bearer-only client (CLI/CI). Nothing to CSRF. + return true + } + headerToken := strings.TrimSpace(r.Header.Get(headerNameCSRF)) + cookie, err := r.Cookie(cookieNameCSRF) + if err != nil || headerToken == "" { + return false + } + cookieValue := strings.TrimSpace(cookie.Value) + if subtle.ConstantTimeCompare([]byte(headerToken), []byte(cookieValue)) != 1 { + return false + } + user, err := apiUsers.Get(username) + if err != nil || user.CSRFToken == "" { + return false } - return strings.TrimSpace(splitToken[1]) + return subtle.ConstantTimeCompare([]byte(cookieValue), []byte(user.CSRFToken)) == 1 } // Handler to check access to a resource based on the authentication enabled @@ -41,12 +108,51 @@ func handlerAuthCheck(h http.Handler, auth, jwtSecret string) http.Handler { // Set middleware values token := extractHeaderToken(r) if token == "" { - http.Redirect(w, r, forbiddenPath, http.StatusForbidden) + if utils.AcceptsJSON(r) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusUnauthorized, + types.ApiErrorResponse{Error: "unauthorized", Code: "unauthorized"}) + return + } + // 302 is required by http.Redirect; the legacy 403 didn't actually trigger + // a redirect in any browser since http.Redirect demands a 3xx status. 
+ http.Redirect(w, r, forbiddenPath, http.StatusFound) return } claims, valid := apiUsers.CheckToken(jwtSecret, token) if !valid { - http.Redirect(w, r, forbiddenPath, http.StatusForbidden) + if utils.AcceptsJSON(r) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusUnauthorized, + types.ApiErrorResponse{Error: "unauthorized", Code: "unauthorized"}) + return + } + // 302 is required by http.Redirect; the legacy 403 didn't actually trigger + // a redirect in any browser since http.Redirect demands a 3xx status. + http.Redirect(w, r, forbiddenPath, http.StatusFound) + return + } + // Match the presented token against the user's currently-stored APIToken + // so that refresh/delete on /users/{username}/token invalidates old JWTs. + // (CheckToken above only validates the signature.) Service users with no + // stored token are rejected immediately. Constant-time comparison guards + // against timing-side-channel leaks of the stored token. + user, uerr := apiUsers.Get(claims.Username) + tokenMatches := uerr == nil && user.APIToken != "" && + subtle.ConstantTimeCompare([]byte(user.APIToken), []byte(token)) == 1 + if !tokenMatches { + if utils.AcceptsJSON(r) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusUnauthorized, + types.ApiErrorResponse{Error: "unauthorized", Code: "unauthorized"}) + return + } + http.Redirect(w, r, forbiddenPath, http.StatusFound) + return + } + // CSRF guard for cookie-authenticated mutating requests. CLI Bearer + // clients are exempt via the cookieNameToken probe inside checkCSRF. 
+ // + if mutatingMethods[r.Method] && !checkCSRF(r, claims.Username) { + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusForbidden, + types.ApiErrorResponse{Error: "csrf token missing or invalid", Code: "csrf"}) return } // Update metadata for the user diff --git a/cmd/api/auth_test.go b/cmd/api/auth_test.go new file mode 100644 index 00000000..965d369f --- /dev/null +++ b/cmd/api/auth_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/jmpsec/osctrl/pkg/config" +) + +func TestHandlerAuthCheckJSONvsRedirect(t *testing.T) { + // A no-op inner handler — handlerAuthCheck should never call it when + // there's no valid token. We just need to assert the failure response. + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("inner handler should not be called when auth fails") + }) + + h := handlerAuthCheck(inner, config.AuthJWT, "test-jwt-secret") + + t.Run("Accept application/json returns 401 JSON", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/anything", nil) + req.Header.Set("Accept", "application/json") + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want 401", rr.Code) + } + ct := rr.Header().Get("Content-Type") + if ct == "" || ct[:16] != "application/json" { + t.Fatalf("Content-Type: got %q, want application/json...", ct) + } + }) + + t.Run("default client gets 302 redirect", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/v1/anything", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusFound { + t.Fatalf("status: got %d, want 302", rr.Code) + } + if rr.Header().Get("Location") == "" { + t.Fatal("missing Location header on redirect") + } + }) +} + +func TestExtractHeaderTokenPrefersBearerThenCookie(t *testing.T) { + cases := []struct { + name string + header string + cookie string + want string + 
}{ + {"bearer header", "Bearer abc.def.ghi", "", "abc.def.ghi"}, + {"cookie fallback", "", "xyz.uvw.123", "xyz.uvw.123"}, + {"bearer wins over cookie", "Bearer header-token", "cookie-token", "header-token"}, + {"no auth at all", "", "", ""}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + if tc.cookie != "" { + req.AddCookie(&http.Cookie{Name: cookieNameToken, Value: tc.cookie}) + } + got := extractHeaderToken(req) + if got != tc.want { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestMutatingMethodsTable(t *testing.T) { + // Lock the contract that GET/HEAD/OPTIONS bypass CSRF and PUT/PATCH/POST/DELETE require it. + for _, m := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} { + if mutatingMethods[m] { + t.Errorf("read-only method %s should not require CSRF", m) + } + } + for _, m := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} { + if !mutatingMethods[m] { + t.Errorf("mutating method %s must require CSRF", m) + } + } +} diff --git a/cmd/api/handlers/carves.go b/cmd/api/handlers/carves.go index f505d4d2..8d9889e4 100644 --- a/cmd/api/handlers/carves.go +++ b/cmd/api/handlers/carves.go @@ -189,6 +189,15 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "path can not be empty", http.StatusInternalServerError, nil) return } + // Validate the path before it's spliced into the osquery SQL via + // carves.GenCarveQuery. Without this gate a CarveLevel operator + // could inject arbitrary osquery (e.g. `'; SELECT 1; --`) into the + // query that gets distributed to every targeted node — pivoting + // "carve a file" into "run any SELECT". 
+ if !carves.ValidCarvePath(c.Path) { + apiErrorResponse(w, "invalid carve path", http.StatusBadRequest, fmt.Errorf("rejected path %q", c.Path)) + return + } // Make sure the user has permissions to run queries in the environments for _, e := range c.Environments { if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, e) { diff --git a/cmd/api/handlers/environments.go b/cmd/api/handlers/environments.go index 50d84e89..6feb721d 100644 --- a/cmd/api/handlers/environments.go +++ b/cmd/api/handlers/environments.go @@ -25,6 +25,44 @@ var ( } ) +// denyEnv emits a 403 AND an audit-log entry pinned to the env handler's +// resource class. Used by the env-handler family for every deny branch +// so cross-tenant probes leave an SoC-alertable trail. The path comes +// from r.URL.Path; envID is 0 (NoEnvironment) when the deny happened +// before env resolution. +func (h *HandlersApi) denyEnv(w http.ResponseWriter, r *http.Request, ctx ContextValue, envID uint, reason string) { + h.AuditLog.Denied(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], reason, auditlog.LogTypeEnvironment, envID) + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("denied: %s for user %s", reason, ctx[ctxUser])) +} + +// projectEnvironmentView strips the env-secret-bearing fields from +// TLSEnvironment to produce the SPA-canonical low-privilege envelope. +// Callers MUST use this when serving env data to a non-admin (UserLevel / +// QueryLevel / CarveLevel) user. 
+func projectEnvironmentView(env environments.TLSEnvironment) types.TLSEnvironmentView { + return types.TLSEnvironmentView{ + ID: env.ID, + CreatedAt: env.CreatedAt, + UpdatedAt: env.UpdatedAt, + UUID: env.UUID, + Name: env.Name, + Hostname: env.Hostname, + Type: env.Type, + Icon: env.Icon, + DebugHTTP: env.DebugHTTP, + ConfigTLS: env.ConfigTLS, + ConfigInterval: env.ConfigInterval, + LoggingTLS: env.LoggingTLS, + LogInterval: env.LogInterval, + QueryTLS: env.QueryTLS, + QueryInterval: env.QueryInterval, + CarvesTLS: env.CarvesTLS, + AcceptEnrolls: env.AcceptEnrolls, + EnrollExpire: env.EnrollExpire, + RemoveExpire: env.RemoveExpire, + } +} + // EnvironmentHandler - GET Handler to return one environment by UUID as JSON func (h *HandlersApi) EnvironmentHandler(w http.ResponseWriter, r *http.Request) { // Debug HTTP if enabled @@ -50,13 +88,21 @@ func (h *HandlersApi) EnvironmentHandler(w http.ResponseWriter, r *http.Request) // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } - // Serialize and serve JSON - log.Debug().Msgf("Returned environment %s", env.Name) + // Decide projection by privilege level: admins on this env (or + // super-admins) receive the full storage struct including secret / + // certificate / flags. UserLevel operators receive the low-privilege + // view that omits enroll credentials. 
h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, env) + if h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + log.Debug().Msgf("Returned environment %s (admin view)", env.Name) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, env) + return + } + log.Debug().Msgf("Returned environment %s (low-priv view)", env.Name) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, projectEnvironmentView(env)) } // EnvironmentMapHandler - GET Handler to return one environment as JSON @@ -79,7 +125,7 @@ func (h *HandlersApi) EnvironmentMapHandler(w http.ResponseWriter, r *http.Reque // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, auditlog.NoEnvironment, "permission check failed") return } // Prepare map by target @@ -112,7 +158,7 @@ func (h *HandlersApi) EnvironmentsHandler(w http.ResponseWriter, r *http.Request // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, auditlog.NoEnvironment, "permission check failed") return } // Get platforms @@ -149,10 +195,15 @@ func (h *HandlersApi) EnvEnrollHandler(w http.ResponseWriter, r *http.Request) { } return } - // Get context data and check access + // Get context data and check access. 
The enroll endpoint exposes the + // env's enroll secret (directly via target=secret, indirectly via the + // one-liners that embed it in the URL, and via target=flags). That + // secret is the only credential needed to enroll nodes via osctrl-tls, + // so it must be gated to AdminLevel on the env, not UserLevel. + // ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract target @@ -185,8 +236,9 @@ func (h *HandlersApi) EnvEnrollHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "invalid target", http.StatusBadRequest, fmt.Errorf("invalid target %s", targetVar)) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned data for environment%s : %s", env.Name, returnData) + // Serialize and serve JSON. Don't log the payload — it contains the + // enroll secret. + log.Debug().Msgf("Returned enroll data for environment %s target=%s", env.Name, targetVar) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiDataResponse{Data: returnData}) } @@ -213,10 +265,12 @@ func (h *HandlersApi) EnvRemoveHandler(w http.ResponseWriter, r *http.Request) { } return } - // Get context data and check access + // Get context data and check access. The remove one-liners embed the + // remove-secret in the URL, so the endpoint must be AdminLevel-gated + // just like the enroll variant. 
ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract target @@ -243,8 +297,9 @@ func (h *HandlersApi) EnvRemoveHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "invalid target", http.StatusBadRequest, fmt.Errorf("invalid target %s", targetVar)) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned data for environment %s : %s", env.Name, returnData) + // Serialize and serve JSON. Don't log the payload — it embeds the + // remove secret. + log.Debug().Msgf("Returned remove data for environment %s target=%s", env.Name, targetVar) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiDataResponse{Data: returnData}) } @@ -274,7 +329,7 @@ func (h *HandlersApi) EnvEnrollActionsHandler(w http.ResponseWriter, r *http.Req // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, env.ID, "permission check failed") return } // Extract action @@ -374,7 +429,7 @@ func (h *HandlersApi) EnvRemoveActionsHandler(w http.ResponseWriter, r *http.Req // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, 
r, ctx, env.ID, "permission check failed") return } // Extract action @@ -433,7 +488,7 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { - apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + h.denyEnv(w, r, ctx, auditlog.NoEnvironment, "permission check failed") return } var e types.ApiEnvRequest @@ -450,6 +505,23 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) apiErrorResponse(w, "invalid data", http.StatusBadRequest, nil) return } + // Validate the optional client-supplied UUID strictly. + // - utils.CheckUUID delegates to google/uuid Parse, accepting only + // canonical UUIDs. EnvUUIDFilter alone is `^[a-z0-9-]+$`, which + // would have happily accepted "-", "a", "deadbeef", etc. + // - ExistsByUUID (vs the polymorphic Exists) ensures a UUID-collision + // check cannot match against an existing env's NAME. The old + // Exists(e.UUID) leaked information across axes. 
+ if e.UUID != "" { + if !utils.CheckUUID(e.UUID) { + apiErrorResponse(w, "invalid uuid", http.StatusBadRequest, fmt.Errorf("rejected uuid %q", e.UUID)) + return + } + if h.Envs.ExistsByUUID(e.UUID) { + apiErrorResponse(w, "uuid already in use", http.StatusConflict, fmt.Errorf("uuid %q collides", e.UUID)) + return + } + } // Check if environment already exists if !h.Envs.Exists(e.Name) { env := h.Envs.Empty(e.Name, e.Hostname) @@ -481,18 +553,18 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) } // Create a tag for this new environment if !h.Tags.Exists(env.Name) { - if err := h.Tags.NewTag( - env.Name, - "Tag for environment "+env.Name, - "", - env.Icon, - ctx[ctxUser], - env.ID, - false, - tags.TagTypeEnv, - ""); err != nil { - msgReturn = fmt.Sprintf("error generating tag %s ", err.Error()) - return + if err := h.Tags.NewTag( + env.Name, + "Tag for environment "+env.Name, + "", + env.Icon, + ctx[ctxUser], + env.ID, + false, + tags.TagTypeEnv, + ""); err != nil { + msgReturn = fmt.Sprintf("error generating tag %s ", err.Error()) + return } } msgReturn = "environment created successfully" @@ -501,21 +573,37 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) return } case "delete": - // Verify request fields + // Validate both name and UUID strictly, then verify they refer to + // the SAME environment so the request can't authorise via one + // env's UUID while targeting another env by name. The previous + // shape (polymorphic Exists(e.UUID) → Delete(e.Name)) allowed + // that authorisation/target split. 
if !environments.EnvNameFilter(e.Name) { apiErrorResponse(w, "invalid environment name", http.StatusBadRequest, nil) return } - if h.Envs.Exists(e.UUID) { - if err := h.Envs.Delete(e.Name); err != nil { - apiErrorResponse(w, "error deleting environment", http.StatusInternalServerError, err) - return - } - msgReturn = "environment deleted successfully" - } else { - apiErrorResponse(w, "environment not found", http.StatusNotFound, fmt.Errorf("environment %s not found", e.Name)) + if e.UUID == "" { + apiErrorResponse(w, "missing environment UUID", http.StatusBadRequest, nil) + return + } + if !utils.CheckUUID(e.UUID) { + apiErrorResponse(w, "invalid environment UUID", http.StatusBadRequest, nil) + return + } + targetEnv, getErr := h.Envs.GetByUUID(e.UUID) + if getErr != nil { + apiErrorResponse(w, "environment not found", http.StatusNotFound, fmt.Errorf("environment %s not found", e.UUID)) + return + } + if targetEnv.Name != e.Name { + apiErrorResponse(w, "name does not match the environment with that UUID", http.StatusBadRequest, fmt.Errorf("uuid %s maps to name %q, body claims %q", e.UUID, targetEnv.Name, e.Name)) + return + } + if err := h.Envs.Delete(targetEnv.Name); err != nil { + apiErrorResponse(w, "error deleting environment", http.StatusInternalServerError, err) return } + msgReturn = "environment deleted successfully" case "edit": // Verify request fields if !environments.EnvUUIDFilter(e.UUID) { diff --git a/cmd/api/handlers/environments_test.go b/cmd/api/handlers/environments_test.go new file mode 100644 index 00000000..bbe332cf --- /dev/null +++ b/cmd/api/handlers/environments_test.go @@ -0,0 +1,91 @@ +package handlers + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/jmpsec/osctrl/pkg/environments" + "gorm.io/gorm" +) + +// TestProjectEnvironmentViewStripsSecrets is the load-bearing regression test +// for the env-secret-containment fix. 
projectEnvironmentView returns the SPA +// envelope served to UserLevel operators; if a future contributor adds a new +// secret-bearing field to TLSEnvironment without extending the projection, +// the field will leak into the low-priv response. This test marshals the +// projection from a fully-populated source struct and asserts every +// known-sensitive substring is absent from the serialized JSON. +func TestProjectEnvironmentViewStripsSecrets(t *testing.T) { + src := environments.TLSEnvironment{ + Model: gorm.Model{ + ID: 1, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + UUID: "11111111-2222-3333-4444-555555555555", + Name: "prod", + Hostname: "osctrl.example.com", + Type: "dev", + Icon: "rocket", + // The fields below must NOT appear in the projection. + Secret: "SECRET-MARKER-enroll", + EnrollSecretPath: "SECRET-MARKER-enroll-path", + RemoveSecretPath: "SECRET-MARKER-remove-path", + Certificate: "SECRET-MARKER-cert", + Flags: "SECRET-MARKER-flags", + Options: "SECRET-MARKER-options", + Schedule: "SECRET-MARKER-schedule", + Packs: "SECRET-MARKER-packs", + Decorators: "SECRET-MARKER-decorators", + ATC: "SECRET-MARKER-atc", + Configuration: "SECRET-MARKER-configuration", + DebPackage: "SECRET-MARKER-deb", + RpmPackage: "SECRET-MARKER-rpm", + MsiPackage: "SECRET-MARKER-msi", + PkgPackage: "SECRET-MARKER-pkg", + EnrollPath: "SECRET-MARKER-enroll-route", + LogPath: "SECRET-MARKER-log-route", + ConfigPath: "SECRET-MARKER-config-route", + QueryReadPath: "SECRET-MARKER-qread-route", + QueryWritePath: "SECRET-MARKER-qwrite-route", + CarverInitPath: "SECRET-MARKER-carver-init", + CarverBlockPath: "SECRET-MARKER-carver-block", + UserID: 42, + // Operational fields that ARE expected in the view: + ConfigInterval: 60, + LogInterval: 30, + QueryInterval: 10, + AcceptEnrolls: true, + } + + view := projectEnvironmentView(src) + out, err := json.Marshal(view) + if err != nil { + t.Fatalf("marshal: %v", err) + } + body := string(out) + + // Field set + tag names 
assertions. + wantFields := []string{ + `"uuid":"11111111-2222-3333-4444-555555555555"`, + `"name":"prod"`, + `"hostname":"osctrl.example.com"`, + `"icon":"rocket"`, + `"config_interval":60`, + `"log_interval":30`, + `"query_interval":10`, + `"accept_enrolls":true`, + } + for _, w := range wantFields { + if !strings.Contains(body, w) { + t.Errorf("expected %q in view JSON, got: %s", w, body) + } + } + + // Every SECRET-MARKER must be absent. + if strings.Contains(body, "SECRET-MARKER") { + t.Fatalf("view leaked at least one secret-bearing field: %s", body) + } +} diff --git a/cmd/api/handlers/login.go b/cmd/api/handlers/login.go index 4926890a..8eb75752 100644 --- a/cmd/api/handlers/login.go +++ b/cmd/api/handlers/login.go @@ -1,9 +1,12 @@ package handlers import ( + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "net/http" + "time" "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/users" @@ -22,10 +25,13 @@ func (h *HandlersApi) LoginHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment by UUID - env, err := h.Envs.GetByUUID(envVar) + // Resolve environment by name OR UUID. The SPA login form lets users type + // the env name ("dev", "prod") because UUIDs are not memorable; the API + // must accept either. Get() uses `name = ? OR uuid = ?` so both shapes + // resolve to the same row. A miss returns 404, not 500. + env, err := h.Envs.Get(envVar) if err != nil { - apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) + apiErrorResponse(w, "environment not found", http.StatusNotFound, nil) return } var l types.ApiLoginRequest @@ -34,31 +40,101 @@ func (h *HandlersApi) LoginHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) return } - // Check credentials + // Check credentials. 
Audit-log every credential failure so SoC tooling + // has a stream to alert on (brute-force, password spray). The IP comes + // from utils.GetIP, which honors X-Real-IP / X-Forwarded-For only when + // the peer is in the --trusted-proxies set. access, user := h.Users.CheckLoginCredentials(l.Username, l.Password) if !access { + h.AuditLog.FailedLogin(l.Username, utils.GetIP(r), "invalid credentials") apiErrorResponse(w, "invalid credentials", http.StatusForbidden, err) return } // Check if user has access to this environment if !h.Users.CheckPermissions(l.Username, users.AdminLevel, env.UUID) { + h.AuditLog.FailedLogin(l.Username, utils.GetIP(r), fmt.Sprintf("no admin access to env %s", env.UUID)) apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use %s by user %s", h.ServiceName, l.Username)) return } - // Do we have a token already? - if user.APIToken == "" { - token, exp, err := h.Users.CreateToken(l.Username, h.ServiceName, l.ExpHours) + // Decide whether to reuse the stored token or mint a fresh one. Re-issue + // when there's no token, when the stored token has already expired (the + // reuse path used to return 500 "token already expired" — a regression + // that locked users out after their first session expired), or when the + // stored token is within 60s of expiring so we don't hand out something + // that will fail mid-request. 
+ var tokenExp time.Time + now := time.Now() + const freshnessWindow = 60 * time.Second + needsRefresh := user.APIToken == "" || user.TokenExpire.Before(now.Add(freshnessWindow)) + if needsRefresh { + var token string + token, tokenExp, err = h.Users.CreateToken(l.Username, h.ServiceName, l.ExpHours) if err != nil { apiErrorResponse(w, "error creating token", http.StatusInternalServerError, err) return } - if err = h.Users.UpdateToken(l.Username, token, exp); err != nil { + if err = h.Users.UpdateToken(l.Username, token, tokenExp); err != nil { apiErrorResponse(w, "error updating token", http.StatusInternalServerError, err) return } user.APIToken = token + } else { + tokenExp = user.TokenExpire } - h.AuditLog.NewLogin(l.Username, r.RemoteAddr) - // Serialize and serve JSON - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiLoginResponse{Token: user.APIToken}) + // Generate a CSRF token: 16 random bytes encoded as 32 hex chars. + // This cookie is NOT HttpOnly so the SPA can read it and echo it back + // via the X-CSRF-Token header on mutating requests. + csrfBytes := make([]byte, 16) + if _, err = rand.Read(csrfBytes); err != nil { + apiErrorResponse(w, "error generating csrf token", http.StatusInternalServerError, err) + return + } + csrfToken := hex.EncodeToString(csrfBytes) + // Persist the CSRF token alongside the user so the auth middleware can + // verify subsequent X-CSRF-Token headers. Without this write the SPA's + // double-submit pattern is purely cosmetic. + // IP comes from utils.GetIP so it matches the format every other site + // writes to last_ip_address (clean IP, X-Real-IP / X-Forwarded-For aware). + clientIP := utils.GetIP(r) + if err := h.Users.UpdateMetadata(clientIP, r.UserAgent(), l.Username, csrfToken); err != nil { + apiErrorResponse(w, "error persisting csrf token", http.StatusInternalServerError, err) + return + } + // Compute cookie Max-Age from token expiry. 
+ maxAge := int(time.Until(tokenExp).Seconds()) + if maxAge <= 0 { + apiErrorResponse(w, "token already expired", http.StatusInternalServerError, fmt.Errorf("token expiry in past or zero: %v", tokenExp)) + return + } + // Set the HttpOnly session cookie. The browser attaches the JWT on + // every request automatically; the SPA never needs (and is unable) to read it from JS. + // Secure: true requires HTTPS. If TLS is terminated at a proxy that speaks + // plain HTTP to this service, set Secure:false in the proxy's cookie rewrite + // rule — do not add an --insecure-cookies flag to keep the surface small. + http.SetCookie(w, &http.Cookie{ + Name: "osctrl_token", + Value: user.APIToken, + Path: "/", + MaxAge: maxAge, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + // Set the CSRF cookie (not HttpOnly — SPA must read it). + http.SetCookie(w, &http.Cookie{ + Name: "osctrl_csrf", + Value: csrfToken, + Path: "/", + MaxAge: maxAge, + HttpOnly: false, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + h.AuditLog.NewLogin(l.Username, clientIP) + // Serialize and serve JSON. Token stays in the body for backward compat + // with CLI consumers that do not use cookies. + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiLoginResponse{ + Token: user.APIToken, + CSRFToken: csrfToken, + }) } diff --git a/cmd/api/handlers/queries.go b/cmd/api/handlers/queries.go index 93afa0b6..36f341a5 100644 --- a/cmd/api/handlers/queries.go +++ b/cmd/api/handlers/queries.go @@ -372,6 +372,14 @@ func (h *HandlersApi) QueryResultsHandler(w http.ResponseWriter, r *http.Request apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } + // Verify the named query belongs to THIS env. logging.GetQueryResults + // filters on `name` only — without this gate a user with QueryLevel on + // env A could pull results from env B by passing B's query name in + // A's URL. 
+ if !h.Queries.Exists(name, env.ID) { + apiErrorResponse(w, "query not found", http.StatusNotFound, nil) + return + } // Get query by name // TODO this is a temporary solution, we need to refactor this and take into consideration the // logger for TLS and whether if the results are stored in the DB or a different DB diff --git a/cmd/api/handlers/settings.go b/cmd/api/handlers/settings.go index 5c9569f1..985fbabd 100644 --- a/cmd/api/handlers/settings.go +++ b/cmd/api/handlers/settings.go @@ -110,8 +110,11 @@ func (h *HandlersApi) SettingsServiceEnvHandler(w http.ResponseWriter, r *http.R apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get settings - serviceSettings, err := h.Settings.RetrieveValues(service, false, settings.NoEnvironmentID) + // Get settings scoped to THIS env. Previously this passed + // NoEnvironmentID and silently returned global settings, which let an + // env-X admin read another env's values as a side-channel via the + // env-scoped route. + serviceSettings, err := h.Settings.RetrieveValues(service, false, env.ID) if err != nil { apiErrorResponse(w, "error getting settings", http.StatusInternalServerError, err) return @@ -196,8 +199,10 @@ func (h *HandlersApi) SettingsServiceEnvJSONHandler(w http.ResponseWriter, r *ht apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get settings - serviceSettings, err := h.Settings.RetrieveValues(service, true, settings.NoEnvironmentID) + // Get settings scoped to THIS env. Same defense as + // SettingsServiceEnvHandler above; was silently returning global + // settings via NoEnvironmentID. 
+ serviceSettings, err := h.Settings.RetrieveValues(service, true, env.ID) if err != nil { apiErrorResponse(w, "error getting settings", http.StatusInternalServerError, err) return diff --git a/cmd/api/handlers/users.go b/cmd/api/handlers/users.go index 759f80c4..7823defb 100644 --- a/cmd/api/handlers/users.go +++ b/cmd/api/handlers/users.go @@ -13,6 +13,26 @@ import ( "github.com/rs/zerolog/log" ) +// projectAdminUserView strips network-and-timing metadata +// (LastIPAddress / LastUserAgent / LastAccess / LastTokenUse) from an +// AdminUser before serialization to a cross-user reader. Operators +// querying their own row use /api/v1/users/me's full UserMeResponse. +func projectAdminUserView(u users.AdminUser) types.AdminUserView { + return types.AdminUserView{ + ID: u.ID, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + Username: u.Username, + Email: u.Email, + Fullname: u.Fullname, + Admin: u.Admin, + Service: u.Service, + UUID: u.UUID, + TokenExpire: u.TokenExpire, + EnvironmentID: u.EnvironmentID, + } +} + // UserHandler - GET Handler for environment users func (h *HandlersApi) UserHandler(w http.ResponseWriter, r *http.Request) { // Debug HTTP if enabled @@ -37,10 +57,12 @@ func (h *HandlersApi) UserHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error getting user", http.StatusInternalServerError, nil) return } - // Serialize and serve JSON + // Serialize and serve the PII-minimized view; the full user record + // is only available to the user themselves via /api/v1/users/me. 
+ // log.Debug().Msgf("Returned user %s", usernameVar) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], auditlog.NoEnvironment) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, user) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, projectAdminUserView(user)) } // UsersHandler - GET Handler for multiple JSON nodes @@ -56,19 +78,24 @@ func (h *HandlersApi) UsersHandler(w http.ResponseWriter, r *http.Request) { return } // Get users - users, err := h.Users.All() + all, err := h.Users.All() if err != nil { apiErrorResponse(w, "error getting users", http.StatusInternalServerError, err) return } - if len(users) == 0 { + if len(all) == 0 { apiErrorResponse(w, "no users", http.StatusNotFound, nil) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned %d users", len(users)) + // PII-minimized view for the cross-user list — see projectAdminUserView. + // + views := make([]types.AdminUserView, 0, len(all)) + for _, u := range all { + views = append(views, projectAdminUserView(u)) + } + log.Debug().Msgf("Returned %d users", len(views)) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], auditlog.NoEnvironment) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, users) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, views) } // UserActionHandler - POST Handler to take actions on a user by username and environment diff --git a/cmd/api/main.go b/cmd/api/main.go index dc569d02..231f7e3e 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -21,9 +21,11 @@ import ( "github.com/jmpsec/osctrl/pkg/logging" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/ratelimit" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/tags" "github.com/jmpsec/osctrl/pkg/users" + "github.com/jmpsec/osctrl/pkg/utils" "github.com/jmpsec/osctrl/pkg/version" "github.com/rs/zerolog" 
"github.com/rs/zerolog/log" @@ -185,8 +187,49 @@ func checkLatestRelease() { } } +// guardAuthMode refuses to start the API with --auth=none unless the operator +// explicitly opts in via OSCTRL_INSECURE_NO_AUTH=1. When the opt-in is set, +// every 60s a loud warning is logged so the deployment cannot drift into +// "auth-off forever" without anyone noticing. +// +// The warning goroutine watches the supplied context so a future graceful +// shutdown path can cancel it cleanly. Today the API has no shutdown signal +// handling so the context never fires — that's acceptable; we get the +// no-leak property for free when shutdown is added. +func guardAuthMode(ctx context.Context, auth string) { + if auth != config.AuthNone { + return + } + if os.Getenv("OSCTRL_INSECURE_NO_AUTH") != "1" { + log.Fatal().Msg("auth=none is disabled by default. Set OSCTRL_INSECURE_NO_AUTH=1 to opt in for local development only — every request will be served as super-admin") + } + go func() { + log.Warn().Msg("INSECURE: osctrl-api running with auth=none — every request is served as super-admin. DO NOT use in production") + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + log.Warn().Msg("INSECURE: osctrl-api running with auth=none — every request is served as super-admin. DO NOT use in production") + } + } + }() +} + // Go go! func osctrlAPIService() { + // Refuse to run unauthenticated unless the operator explicitly opts in. + guardAuthMode(context.Background(), flagParams.Service.Auth) + // Configure forwarding-header trust. Empty (default) means utils.GetIP + // ignores X-Forwarded-For / X-Real-IP and always uses RemoteAddr, so + // an internet attacker can't spoof IPs to defeat rate-limits or + // poison the audit log. 
+ if tp := strings.TrimSpace(flagParams.Service.TrustedProxies); tp != "" { + utils.SetTrustedProxies(strings.Split(tp, ",")) + log.Info().Msgf("Trusting forwarding headers from: %s", tp) + } // ////////////////////////////// Backend log.Info().Msg("Initializing backend...") for { @@ -265,7 +308,6 @@ func osctrlAPIService() { handlers.WithAuditLog(auditLog), handlers.WithDebugHTTP(flagParams.Debug), handlers.WithOsqueryValues(*flagParams.Osquery), - ) // ///////////////////////// API @@ -284,7 +326,16 @@ func osctrlAPIService() { muxAPI.HandleFunc("GET "+_apiPath(checksNoAuthPath), handlersApi.CheckHandlerNoAuth) // ///////////////////////// UNAUTHENTICATED - muxAPI.HandleFunc("POST "+_apiPath(apiLoginPath)+"/{env}", handlersApi.LoginHandler) + // Login is the only password-acceptance surface on the API. Cap to + // 10 attempts per IP per minute (token-bucket; bursts of 10, refill + // at one token per 6s) and 429 the rest. Rejections are audit-logged by + // the LoginHandler and by the limiter's onReject callback so SoC tooling sees the spray. 
+ // + loginLimiter := ratelimit.New(10, time.Minute, 10*time.Minute) + loginRateLimit := loginLimiter.HTTPMiddleware(ratelimit.KeyByIP, func(r *http.Request, key string) { + handlersApi.AuditLog.FailedLogin("", utils.GetIP(r), "rate limit exceeded") + }) + muxAPI.Handle("POST "+_apiPath(apiLoginPath)+"/{env}", loginRateLimit(http.HandlerFunc(handlersApi.LoginHandler))) // ///////////////////////// AUTHENTICATED // API: check auth muxAPI.Handle( @@ -392,7 +443,7 @@ func osctrlAPIService() { handlerAuthCheck(http.HandlerFunc(handlersApi.EnvEnrollActionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "GET "+_apiPath(apiEnvironmentsPath)+"/{env}/remove/{target}", - handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvRemoveHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "POST "+_apiPath(apiEnvironmentsPath)+"/{env}/remove/{action}", handlerAuthCheck(http.HandlerFunc(handlersApi.EnvRemoveActionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) diff --git a/deploy/config/admin.yml b/deploy/config/admin.yml index 1099e916..a5b3d155 100644 --- a/deploy/config/admin.yml +++ b/deploy/config/admin.yml @@ -10,7 +10,7 @@ service: host: osctrl.net # Valid values: "none", "json", "db", "saml", "oidc", "oauth" auth: none - auditLog: false + auditLog: true # Database configuration db: diff --git a/deploy/config/api.yml b/deploy/config/api.yml index e77c8c1f..9e90fba7 100644 --- a/deploy/config/api.yml +++ b/deploy/config/api.yml @@ -8,9 +8,19 @@ service: # Valid values: "json", "console" logFormat: json host: osctrl.net - # Valid values: "none", "json", "db", "saml", "oidc", "oauth" - auth: none - auditLog: false + # Valid values: "jwt", "none". `none` requires OSCTRL_INSECURE_NO_AUTH=1 + # in the environment and is intended for local-dev only — it impersonates + # super-admin on every request. 
Production deployments MUST use `jwt`. + auth: jwt + auditLog: true + # Comma-separated CIDR list whose X-Real-IP / X-Forwarded-For headers + # utils.GetIP will trust. Leave empty (default) when osctrl-api is + # directly internet-facing — forwarding headers are then ignored and + # RemoteAddr is used verbatim, preventing header-spoofed rate-limit + # bypass and audit-log poisoning. Set to your edge proxy's CIDR(s) + # when osctrl-api sits behind a trusted reverse proxy (e.g. + # `10.0.0.0/8` or `192.0.2.1/32,2001:db8::/64`). + trustedProxies: "" # Database configuration db: diff --git a/go.mod b/go.mod index ac619ec2..fb4fcb49 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,7 @@ require ( golang.org/x/oauth2 v0.36.0 golang.org/x/term v0.42.0 golang.org/x/text v0.36.0 + golang.org/x/time v0.15.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gorm.io/driver/mysql v1.6.0 gorm.io/driver/postgres v1.6.0 diff --git a/go.sum b/go.sum index 888ff6fc..b15d7339 100644 --- a/go.sum +++ b/go.sum @@ -230,6 +230,8 @@ golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/auditlog/audit.go b/pkg/auditlog/audit.go index 29dc5487..bd1d3246 100644 --- a/pkg/auditlog/audit.go +++ b/pkg/auditlog/audit.go @@ -114,6 +114,35 @@ func (m *AuditLogManager) NewLogin(username, ip string) { } } +// 
FailedLogin records a failed login attempt — invalid credentials, missing +// permission, or any other reason the login flow refused to mint a token. +// `reason` is a short free-text string suitable for SoC alerting and MUST +// NOT contain the offered password. Severity warning so it sticks out next +// to the successful-login firehose. +func (m *AuditLogManager) FailedLogin(username, ip, reason string) { + if !m.Enabled { + return + } + line := fmt.Sprintf("failed login for user %s: %s", username, reason) + if err := m.CreateNew(username, line, ip, LogTypeLogin, SeverityWarning, NoEnvironment); err != nil { + log.Err(err).Msg("error creating failed-login audit log") + } +} + +// FailedEnroll records a failed osquery-node enrollment attempt — invalid +// env secret, denied env, malformed payload. Severity warning, scoped to +// the env in the path (envID == 0 when the env itself was the failure +// reason). +func (m *AuditLogManager) FailedEnroll(ip, envName, reason string, envID uint) { + if !m.Enabled { + return + } + line := fmt.Sprintf("failed enroll for env %s: %s", envName, reason) + if err := m.CreateNew("osctrl-tls", line, ip, LogTypeNode, SeverityWarning, envID); err != nil { + log.Err(err).Msg("error creating failed-enroll audit log") + } +} + // NewLogout - create new logout audit log entry func (m *AuditLogManager) NewLogout(username, ip string) { if !m.Enabled { @@ -224,6 +253,22 @@ func (m *AuditLogManager) EnvAction(username, action, ip string, envID uint) { } } +// Denied records a 403/forbidden access attempt at SeverityWarning so SoC +// dashboards can surface cross-tenant probes. logType pins the resource +// class (LogTypeEnvironment for env handlers, LogTypeNode for node +// handlers, etc.). envID is the env the resource lives in, or +// NoEnvironment when the deny happened before env resolution. The reason +// field is short free text — never echo back the offered credential. 
+func (m *AuditLogManager) Denied(username, path, ip, reason string, logType, envID uint) { + if !m.Enabled { + return + } + line := fmt.Sprintf("denied access for user %s to %s: %s", username, path, reason) + if err := m.CreateNew(username, line, ip, logType, SeverityWarning, envID); err != nil { + log.Err(err).Msg("error creating denied-access audit log") + } +} + // SettingsAction - create new settings action audit log entry func (m *AuditLogManager) SettingsAction(username, action, ip string) { if !m.Enabled { diff --git a/pkg/carves/utils.go b/pkg/carves/utils.go index d800ecf3..bd224a96 100644 --- a/pkg/carves/utils.go +++ b/pkg/carves/utils.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "fmt" + "regexp" "strings" "github.com/jmpsec/osctrl/pkg/utils" @@ -78,7 +79,35 @@ func GenCarveName() string { return "carve_" + utils.RandomForNames() } -// Helper to generate the carve query +// validCarvePath restricts the characters that can appear in a carve +// path. The carve string is concatenated into the osquery SQL that +// every targeted node executes; without this gate a CarveLevel +// operator could inject arbitrary osquery (e.g. `'; SELECT 1; --`) and +// pivot from "exfil this path" to "run any SELECT against your nodes". +// +// The character class covers realistic carve targets across the three +// platforms: absolute POSIX paths (Linux/macOS), Windows paths with +// backslashes and drive letters, and glob wildcards (* and ?). It +// explicitly excludes single quote, semicolon, and comment markers. +var validCarvePath = regexp.MustCompile(`^[/A-Za-z0-9._\-\\:*?]+$`) + +// ValidCarvePath reports whether s is a safe value to splice into +// GenCarveQuery. Callers MUST verify before calling GenCarveQuery — +// the result is interpolated directly into SQL. +func ValidCarvePath(s string) bool { + if s == "" { + return false + } + return validCarvePath.MatchString(s) +} + +// Helper to generate the carve query. 
+// +// `file` is interpolated into the SQL string verbatim. The caller MUST +// have validated it via ValidCarvePath beforehand — passing an +// unvalidated user-controlled value here lets the requesting operator +// run arbitrary osquery on every targeted host, which is well beyond +// the "carve a file" capability the endpoint advertises. func GenCarveQuery(file string, glob bool) string { if glob { return "SELECT * FROM carves WHERE carve=1 AND path LIKE '" + file + "';" diff --git a/pkg/carves/utils_test.go b/pkg/carves/utils_test.go new file mode 100644 index 00000000..03824410 --- /dev/null +++ b/pkg/carves/utils_test.go @@ -0,0 +1,51 @@ +package carves + +import ( + "strings" + "testing" +) + +// TestValidCarvePath locks the character allowlist that gates GenCarveQuery. +func TestValidCarvePath(t *testing.T) { + good := []string{ + "/etc/passwd", + "/var/log/auth.log", + "C:\\Windows\\System32\\drivers\\etc\\hosts", + "/Users/alice/Library/Application_Support/com.example/cfg", + "/var/log/*.log", + "/var/log/auth?.log", + } + for _, p := range good { + if !ValidCarvePath(p) { + t.Errorf("ValidCarvePath(%q): expected true", p) + } + } + bad := []string{ + "", + "'; SELECT 1; --", + "/var/log/a'b", + "/var/log/a;b", + "/var/log/a b", // space + "/var/log/a\"b", + "/var/log/a\nb", + } + for _, p := range bad { + if ValidCarvePath(p) { + t.Errorf("ValidCarvePath(%q): expected false", p) + } + } +} + +// TestGenCarveQueryShape sanity-checks the SQL shape for both glob and +// exact match. Real callers MUST validate file via ValidCarvePath first; +// this test exercises the happy path only. 
+func TestGenCarveQueryShape(t *testing.T) { + q1 := GenCarveQuery("/etc/passwd", false) + if !strings.Contains(q1, "path = '/etc/passwd'") { + t.Errorf("exact: got %q", q1) + } + q2 := GenCarveQuery("/var/log/*.log", true) + if !strings.Contains(q2, "path LIKE '/var/log/*.log'") { + t.Errorf("glob: got %q", q2) + } +} diff --git a/pkg/config/flags.go b/pkg/config/flags.go index eb48dede..96fbc053 100644 --- a/pkg/config/flags.go +++ b/pkg/config/flags.go @@ -194,8 +194,8 @@ func initServiceFlags(params *ServiceParameters) []cli.Flag { &cli.StringFlag{ Name: "auth", Aliases: []string{"A"}, - Value: AuthNone, - Usage: "Authentication mechanism for the service", + Value: AuthJWT, + Usage: "Authentication mechanism for the service (jwt|none — `none` requires OSCTRL_INSECURE_NO_AUTH=1)", Sources: cli.EnvVars("SERVICE_AUTH"), Destination: ¶ms.Service.Auth, }, @@ -216,11 +216,18 @@ func initServiceFlags(params *ServiceParameters) []cli.Flag { &cli.BoolFlag{ Name: "audit-log", Aliases: []string{"audit"}, - Value: false, - Usage: "Enable audit log for the service. Logs all sensitive actions", + Value: true, + Usage: "Enable audit log for the service. Logs sensitive actions (logins, env mutations, query/carve runs, etc.). Disable only for local dev — production deployments MUST keep this on so SoC tooling has a stream to alert on.", Sources: cli.EnvVars("AUDIT_LOG"), Destination: ¶ms.Service.AuditLog, }, + &cli.StringFlag{ + Name: "trusted-proxies", + Value: "", + Usage: "Comma-separated CIDR list whose X-Real-IP / X-Forwarded-For headers will be honored. 
Empty (default) ignores forwarding headers and uses RemoteAddr verbatim — prevents header-spoofed rate-limit bypass and audit-log poisoning.", + Sources: cli.EnvVars("SERVICE_TRUSTED_PROXIES"), + Destination: ¶ms.Service.TrustedProxies, + }, } } diff --git a/pkg/config/types.go b/pkg/config/types.go index 1c4295ca..0f64d5e5 100644 --- a/pkg/config/types.go +++ b/pkg/config/types.go @@ -120,6 +120,11 @@ type YAMLConfigurationService struct { Host string `yaml:"host"` Auth string `yaml:"auth"` AuditLog bool `yaml:"auditLog"` + // TrustedProxies is a comma-separated list of CIDRs whose + // X-Real-IP / X-Forwarded-For headers utils.GetIP will honor. + // Default empty → forwarding headers are ignored and the + // connection's RemoteAddr is used. + TrustedProxies string `yaml:"trustedProxies"` } // YAMLConfigurationDB to hold all backend configuration values diff --git a/pkg/environments/env-cache.go b/pkg/environments/env-cache.go index 1f66b843..31226ef1 100644 --- a/pkg/environments/env-cache.go +++ b/pkg/environments/env-cache.go @@ -9,6 +9,19 @@ import ( const ( cacheName = "environments" + // envCacheTTL is the maximum time a TLSEnvironment can sit in the + // EnvCache before the next request refetches from the database. + // + // osctrl-tls holds this cache; osctrl-api mutates env rows in the + // same DB from a different process. There is no IPC channel between + // the two, so envCache invalidation is TTL-based — the TTL bounds + // the window during which enroll-secret rotations, env deletions, + // or config-PATCH changes can be served stale by osctrl-tls. + // + // Kept at the historical 2h cleanup interval; operators who need + // faster invalidation can rotate via `osctrl-tls` restart or tune + // this constant locally. 
+ envCacheTTL = 2 * time.Hour ) // EnvCache provides cached access to TLS environments @@ -22,9 +35,8 @@ type EnvCache struct { // NewEnvCache creates a new environment cache func NewEnvCache(envs EnvManager) *EnvCache { - // Create a new cache with a 10-minute cleanup interval envCache := cache.NewMemoryCache( - cache.WithCleanupInterval[TLSEnvironment](2*time.Hour), + cache.WithCleanupInterval[TLSEnvironment](envCacheTTL), cache.WithName[TLSEnvironment](cacheName), ) @@ -47,24 +59,27 @@ func (ec *EnvCache) GetByUUID(ctx context.Context, uuid string) (TLSEnvironment, return TLSEnvironment{}, err } - ec.cache.Set(ctx, uuid, env, 2*time.Hour) + ec.cache.Set(ctx, uuid, env, envCacheTTL) return env, nil } -// InvalidateEnv removes a specific environment from the cache +// InvalidateEnv removes a specific environment from the cache. Callers +// that mutate env rows in the same process SHOULD invoke this so the +// next request refetches the row without waiting for the TTL. func (ec *EnvCache) InvalidateEnv(ctx context.Context, uuid string) { ec.cache.Delete(ctx, uuid) } -// InvalidateAll clears the entire cache +// InvalidateAll clears the entire cache. Used on bulk operations or +// after operator-driven secret rotations. 
func (ec *EnvCache) InvalidateAll(ctx context.Context) { ec.cache.Clear(ctx) } // UpdateEnvInCache updates an environment in the cache func (ec *EnvCache) UpdateEnvInCache(ctx context.Context, env TLSEnvironment) { - ec.cache.Set(ctx, env.UUID, env, 2*time.Hour) + ec.cache.Set(ctx, env.UUID, env, envCacheTTL) } // Close stops the cleanup goroutine and releases resources diff --git a/pkg/environments/environments.go b/pkg/environments/environments.go index a419e382..848cece5 100644 --- a/pkg/environments/environments.go +++ b/pkg/environments/environments.go @@ -214,13 +214,35 @@ func (environment *EnvManager) Create(env *TLSEnvironment) error { return nil } -// Exists checks if TLS Environment exists already +// Exists checks if TLS Environment exists already by name OR uuid (polymorphic). +// Prefer ExistsByUUID / ExistsByName when the caller knows which axis to check — +// the polymorphic variant can confuse a UUID-collision check with a name match +// and vice versa, which leaked information across axes in EnvActionsHandler. +// (Cluster-4 review item — see ExistsByUUID below.) func (environment *EnvManager) Exists(identifier string) bool { var results int64 environment.DB.Model(&TLSEnvironment{}).Where("name = ? OR uuid = ?", identifier, identifier).Count(&results) return (results > 0) } +// ExistsByUUID checks if a TLS Environment exists by UUID only. +// Use this when validating a client-supplied UUID for collision before +// creating a new environment, or for unambiguous delete-by-UUID semantics. +func (environment *EnvManager) ExistsByUUID(uuid string) bool { + var results int64 + environment.DB.Model(&TLSEnvironment{}).Where("uuid = ?", uuid).Count(&results) + return (results > 0) +} + +// ExistsByName checks if a TLS Environment exists by name only. +// (Companion to ExistsByUUID — provided for symmetry; callers preferring the +// polymorphic Exists() can keep using it.) 
+func (environment *EnvManager) ExistsByName(name string) bool { + var results int64 + environment.DB.Model(&TLSEnvironment{}).Where("name = ?", name).Count(&results) + return (results > 0) +} + // ExistsGet checks if TLS Environment exists already and returns it func (environment *EnvManager) ExistsGet(identifier string) (bool, TLSEnvironment) { e, err := environment.Get(identifier) diff --git a/pkg/ratelimit/ratelimit.go b/pkg/ratelimit/ratelimit.go new file mode 100644 index 00000000..85a4eb89 --- /dev/null +++ b/pkg/ratelimit/ratelimit.go @@ -0,0 +1,144 @@ +// Package ratelimit provides a small token-bucket rate-limit middleware +// used to protect anonymous attack surfaces (login, enroll) from +// brute-force / password-spray. +// +// The Limiter is keyed by a caller-supplied function (IP, IP+username, +// etc.) so the same primitive can fan out to per-endpoint policies. +package ratelimit + +import ( + "net/http" + "sync" + "time" + + "github.com/jmpsec/osctrl/pkg/utils" + "golang.org/x/time/rate" +) + +// DefaultMaxBuckets is the cap on the per-key map size. Once exceeded, +// new keys all share a single overflow bucket, so an attacker churning +// arbitrary keys (X-Forwarded-For spoofing or a similar primitive in a +// future surface) cannot grow the limiter's memory footprint unbounded. +const DefaultMaxBuckets = 100_000 + +// Limiter is a sharded map of token buckets keyed by an arbitrary string. +// Buckets age out after `evictAfter` of inactivity so the map doesn't grow +// unbounded. Eviction is amortized — the full O(N) scan runs at most once +// per `evictAfter/2` so a single hot-path Allow doesn't pay the cost. +// When the map exceeds maxBuckets, new keys collapse onto a shared +// overflow bucket; the spray still gets rate-limited (just not per-key) +// and memory stays bounded. 
+type Limiter struct { + mu sync.Mutex + buckets map[string]*entry + overflow *rate.Limiter + maxBuckets int + rate rate.Limit + burst int + evictAfter time.Duration + lastEviction time.Time + evictInterval time.Duration +} + +type entry struct { + limiter *rate.Limiter + lastSeen time.Time +} + +// New returns a Limiter that allows up to `burst` events per key over `per`, +// with steady-state refill at `burst/per`. evictAfter is the inactivity +// window after which a key's bucket is forgotten — pick something larger +// than `per` so genuine retries don't reset their bucket. +// +// The bucket map is capped at DefaultMaxBuckets entries. Operators that +// need a different cap can construct via NewWithCap. +func New(burst int, per, evictAfter time.Duration) *Limiter { + return NewWithCap(burst, per, evictAfter, DefaultMaxBuckets) +} + +// NewWithCap is New with an explicit ceiling on the per-key map size. +func NewWithCap(burst int, per, evictAfter time.Duration, maxBuckets int) *Limiter { + interval := evictAfter / 2 + if interval <= 0 { + interval = time.Second + } + if maxBuckets <= 0 { + maxBuckets = DefaultMaxBuckets + } + r := rate.Every(per / time.Duration(burst)) + return &Limiter{ + buckets: make(map[string]*entry), + overflow: rate.NewLimiter(r, burst), + maxBuckets: maxBuckets, + rate: r, + burst: burst, + evictAfter: evictAfter, + evictInterval: interval, + } +} + +// Allow returns true if the supplied key can perform one event under the +// current bucket state. Side-effect: the bucket is created on first use +// and idle buckets are GC'd opportunistically (at most once per +// evictInterval to keep the hot path constant-time). When the map is +// already at maxBuckets and the key has no existing bucket, the call +// falls back to the shared overflow bucket so memory stays bounded. 
+func (l *Limiter) Allow(key string) bool { + now := time.Now() + l.mu.Lock() + defer l.mu.Unlock() + // Amortized eviction: walk the map only when the throttle says it's + // time. Each Allow is O(1) on the steady-state path, which keeps the + // lock-held duration bounded under load. + if now.Sub(l.lastEviction) >= l.evictInterval { + for k, e := range l.buckets { + if now.Sub(e.lastSeen) > l.evictAfter { + delete(l.buckets, k) + } + } + l.lastEviction = now + } + if e, ok := l.buckets[key]; ok { + e.lastSeen = now + return e.limiter.Allow() + } + // New key. If the map is at the cap, route through the shared + // overflow bucket — spray attackers can saturate it, but legitimate + // keys that already have a bucket still get their own quota. + if len(l.buckets) >= l.maxBuckets { + return l.overflow.Allow() + } + e := &entry{limiter: rate.NewLimiter(l.rate, l.burst), lastSeen: now} + l.buckets[key] = e + return e.limiter.Allow() +} + +// HTTPMiddleware returns a middleware that rejects requests with 429 when +// `keyFn(r)` exceeds the limit. keyFn is responsible for choosing the +// dimension (e.g., utils.GetIP(r), or `utils.GetIP(r) + ":" + username`). +// +// onReject is invoked synchronously when a request is rejected — use it to +// emit an audit-log entry. May be nil. +func (l *Limiter) HTTPMiddleware(keyFn func(*http.Request) string, onReject func(*http.Request, string)) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := keyFn(r) + if !l.Allow(key) { + if onReject != nil { + onReject(r, key) + } + w.Header().Set("Retry-After", "60") + http.Error(w, "too many requests", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// KeyByIP is a convenience keyFn for IP-based rate limiting via +// utils.GetIP, which honors X-Real-IP / X-Forwarded-For only when the +// connecting peer is inside a configured trusted-proxy CIDR (see +// utils.SetTrustedProxies).
+func KeyByIP(r *http.Request) string { + return utils.GetIP(r) +} diff --git a/pkg/ratelimit/ratelimit_test.go b/pkg/ratelimit/ratelimit_test.go new file mode 100644 index 00000000..61dfc466 --- /dev/null +++ b/pkg/ratelimit/ratelimit_test.go @@ -0,0 +1,108 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// TestAllowBurst verifies a Limiter allows up to `burst` calls in a single +// window and then refuses the (burst+1)th. +func TestAllowBurst(t *testing.T) { + l := New(3, time.Second, time.Minute) + for i := 0; i < 3; i++ { + if !l.Allow("k") { + t.Fatalf("expected Allow #%d to return true", i+1) + } + } + if l.Allow("k") { + t.Fatal("expected the burst+1 request to be rejected") + } +} + +// TestAllowSeparateKeys verifies buckets don't bleed between keys. +func TestAllowSeparateKeys(t *testing.T) { + l := New(2, time.Second, time.Minute) + l.Allow("a") + l.Allow("a") + if l.Allow("a") { + t.Fatal("key a should be over budget") + } + if !l.Allow("b") { + t.Fatal("key b has its own budget") + } +} + +// TestHTTPMiddleware429s verifies the middleware returns 429 + Retry-After +// when the bucket is empty and calls onReject for telemetry. 
+func TestHTTPMiddleware429s(t *testing.T) { + l := New(1, time.Second, time.Minute) + rejected := 0 + mw := l.HTTPMiddleware( + func(r *http.Request) string { return "fixed" }, + func(r *http.Request, key string) { rejected++ }, + ) + allowed := mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + first := httptest.NewRecorder() + allowed.ServeHTTP(first, httptest.NewRequest("POST", "/login", nil)) + if first.Code != http.StatusOK { + t.Fatalf("first request: got %d, want 200", first.Code) + } + + second := httptest.NewRecorder() + allowed.ServeHTTP(second, httptest.NewRequest("POST", "/login", nil)) + if second.Code != http.StatusTooManyRequests { + t.Fatalf("second request: got %d, want 429", second.Code) + } + if got := second.Header().Get("Retry-After"); got == "" { + t.Fatal("missing Retry-After header on 429") + } + if rejected != 1 { + t.Fatalf("onReject calls: got %d, want 1", rejected) + } +} + +// TestBucketCapOverflow — once `maxBuckets` is reached, additional +// distinct keys all route through the shared overflow bucket so map +// growth is bounded. Existing keys keep their per-key budget. +func TestBucketCapOverflow(t *testing.T) { + // burst=1, per=time.Hour — each per-key bucket allows exactly one + // request before refilling. + l := NewWithCap(1, time.Hour, time.Minute, 2) + + // Two keys → both get their own bucket and one Allow each. + if !l.Allow("k1") { + t.Fatal("k1 first Allow must succeed") + } + if !l.Allow("k2") { + t.Fatal("k2 first Allow must succeed") + } + if l.Allow("k1") { + t.Fatal("k1 second Allow must fail (per-key budget exhausted)") + } + + // k3 through k6 are NEW keys past the cap. They all share the + // overflow bucket (burst 1). The first one consumes the overflow + // burst; the rest must be denied.
+ got := 0 + for _, k := range []string{"k3", "k4", "k5", "k6"} { + if l.Allow(k) { + got++ + } + } + if got > 1 { + t.Fatalf("overflow burst must be 1, got %d successful Allows on capped keys", got) + } + + // Verify the map didn't grow past the cap. + l.mu.Lock() + size := len(l.buckets) + l.mu.Unlock() + if size > 2 { + t.Fatalf("bucket map exceeded cap: size=%d, cap=2", size) + } +} diff --git a/pkg/types/types.go b/pkg/types/types.go index c816f395..2536441a 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,5 +1,7 @@ package types +import "time" + // OsqueryTable to show tables to query type OsqueryTable struct { Name string `json:"name"` @@ -85,6 +87,7 @@ type ApiLoginRequest struct { // ApiErrorResponse to be returned to API requests with the error message type ApiErrorResponse struct { Error string `json:"error"` + Code string `json:"code,omitempty"` } // ApiQueriesResponse to be returned to API requests for queries @@ -104,7 +107,8 @@ type ApiDataResponse struct { // ApiLoginResponse to be returned to API login requests with the generated token type ApiLoginResponse struct { - Token string `json:"token"` + Token string `json:"token"` + CSRFToken string `json:"csrf_token,omitempty"` } // ApiActionsRequest to receive action requests @@ -155,3 +159,56 @@ type ApiUserRequest struct { API bool `json:"api"` Environments []string `json:"environments"` } + +// TLSEnvironmentView is the low-privilege projection of an environment. +// UserLevel operators (env scope) need basic env metadata so the SPA can +// render its env switcher / dashboard / table chrome — but they MUST NOT +// receive the enroll secret, the certificate, or one-liner URLs that +// embed the secret. The full storage struct is admin-only via +// EnvironmentAdminHandler. 
+type TLSEnvironmentView struct { + ID uint `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UUID string `json:"uuid"` + Name string `json:"name"` + Hostname string `json:"hostname"` + Type string `json:"type"` + Icon string `json:"icon"` + DebugHTTP bool `json:"debug_http"` + ConfigTLS bool `json:"config_tls"` + ConfigInterval int `json:"config_interval"` + LoggingTLS bool `json:"logging_tls"` + LogInterval int `json:"log_interval"` + QueryTLS bool `json:"query_tls"` + QueryInterval int `json:"query_interval"` + CarvesTLS bool `json:"carves_tls"` + AcceptEnrolls bool `json:"accept_enrolls"` + EnrollExpire time.Time `json:"enroll_expire"` + RemoveExpire time.Time `json:"remove_expire"` +} + +// AdminUserView is the PII-minimized projection of an AdminUser for +// the GET /api/v1/users and GET /api/v1/users/{username} endpoints. +// Drops LastIPAddress / LastUserAgent / LastAccess / LastTokenUse: a +// super-admin reading another super-admin's record gets enough to +// manage them (username, email, fullname, admin/service flags, env +// scope) but not the network/timing metadata that helps an attacker +// who later compromises one super-admin profile target the others. +// +// Users querying THEIR OWN record see the metadata they need via the +// pre-existing UserMeResponse from /api/v1/users/me — this view is +// strictly for the cross-user "list / inspect another admin" paths. 
+type AdminUserView struct { + ID uint `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Username string `json:"username"` + Email string `json:"email"` + Fullname string `json:"fullname"` + Admin bool `json:"admin"` + Service bool `json:"service"` + UUID string `json:"uuid"` + TokenExpire time.Time `json:"token_expire"` + EnvironmentID uint `json:"environment_id"` +} diff --git a/pkg/users/permissions_test.go b/pkg/users/permissions_test.go index f91370bc..f4507caf 100644 --- a/pkg/users/permissions_test.go +++ b/pkg/users/permissions_test.go @@ -16,7 +16,7 @@ import ( func setupTestManagerForPermissions(t *testing.T) (*UserManager, sqlmock.Sqlmock) { conf := config.YAMLConfigurationJWT{ - JWTSecret: "test", + JWTSecret: "test-secret-must-be-at-least-32-bytes-long", HoursToExpire: 1, } mockDB, mock, err := sqlmock.New() diff --git a/pkg/users/users.go b/pkg/users/users.go index 5bc08716..8bd18989 100644 --- a/pkg/users/users.go +++ b/pkg/users/users.go @@ -54,12 +54,21 @@ type UserManager struct { JWTConfig *config.YAMLConfigurationJWT } +// MinJWTSecretBytes is the minimum acceptable length of the HMAC JWT secret +// (RFC 7518 §3.2 recommends a key at least as wide as the hash output for +// HS256 ⇒ 32 bytes). Generate one with: openssl rand -base64 48 +const MinJWTSecretBytes = 32 + // CreateUserManager to initialize the users struct and tables func CreateUserManager(backend *gorm.DB, jwtconfig *config.YAMLConfigurationJWT) *UserManager { - // Check if JWT is not empty + // JWT secret must be present and long enough for HS256. if jwtconfig.JWTSecret == "" { log.Fatal().Msgf("JWT Secret can not be empty") } + if len(jwtconfig.JWTSecret) < MinJWTSecretBytes { + log.Fatal().Msgf("JWT Secret too short: have %d bytes, need >= %d. 
Generate one with: openssl rand -base64 48", + len(jwtconfig.JWTSecret), MinJWTSecretBytes) + } u := &UserManager{DB: backend, JWTConfig: jwtconfig} // table admin_users if err := backend.AutoMigrate(&AdminUser{}); err != nil { @@ -72,10 +81,14 @@ func CreateUserManager(backend *gorm.DB, jwtconfig *config.YAMLConfigurationJWT) return u } +// BcryptCost is the bcrypt work factor for password hashing. 12 is the +// 2026 commodity-CPU recommendation; bcrypt.DefaultCost is 10. +const BcryptCost = 12 + // HashTextWithSalt to hash text before store it func (m *UserManager) HashTextWithSalt(text string) (string, error) { saltedBytes := []byte(text) - hashedBytes, err := bcrypt.GenerateFromPassword(saltedBytes, bcrypt.DefaultCost) + hashedBytes, err := bcrypt.GenerateFromPassword(saltedBytes, BcryptCost) if err != nil { return "", err } @@ -88,7 +101,12 @@ func (m *UserManager) HashPasswordWithSalt(password string) (string, error) { return m.HashTextWithSalt(password) } -// CheckLoginCredentials to check provided login credentials by matching hashes +// CheckLoginCredentials matches password hashes and, on a successful +// match, opportunistically re-hashes the password at the current +// BcryptCost when the stored hash is below it. Users created under an +// older cost migrate transparently on their next login. The rehash +// failure is non-fatal — login succeeds even if the rehash write +// fails (next login retries). 
func (m *UserManager) CheckLoginCredentials(username, password string) (bool, AdminUser) { // Check if we should include service users user, err := m.Get(username) @@ -98,10 +116,21 @@ func (m *UserManager) CheckLoginCredentials(username, password string) (bool, Ad // Check for hash matching p := []byte(password) existing := []byte(user.PassHash) - err = bcrypt.CompareHashAndPassword(existing, p) - if err != nil { + if err := bcrypt.CompareHashAndPassword(existing, p); err != nil { return false, AdminUser{} } + // Successful login — rehash if the stored cost is below current. + if cost, cerr := bcrypt.Cost(existing); cerr == nil && cost < BcryptCost { + if newHash, herr := m.HashPasswordWithSalt(password); herr == nil { + if uerr := m.DB.Model(&user).Update("pass_hash", newHash).Error; uerr != nil { + log.Err(uerr).Msgf("rehash-on-login: failed to persist new pass_hash for %s", username) + } else { + user.PassHash = newHash + } + } else { + log.Err(herr).Msgf("rehash-on-login: bcrypt cost upgrade failed for %s", username) + } + } return true, user } @@ -130,10 +159,16 @@ func (m *UserManager) CreateToken(username, issuer string, expHours int) (string return tokenString, expirationTime, nil } -// CheckToken to verify if a token used is valid +// CheckToken to verify if a token used is valid. +// Pins the signing algorithm to HMAC so an attacker cannot swap to `alg:none` +// or RS256-with-public-key (RS-vs-HS confusion) — defense-in-depth on top of +// the underlying library's own mitigations. 
func (m *UserManager) CheckToken(jwtSecret, tokenStr string) (TokenClaims, bool) { claims := &TokenClaims{} tkn, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected jwt signing method: %v", token.Header["alg"]) + } return []byte(jwtSecret), nil }) if err != nil { @@ -234,6 +269,18 @@ func (m *UserManager) IsAdmin(username string) bool { return (results > 0) } +// CountAdmins returns the number of active admin (Admin=true) users. +// Used by the permissions API to refuse demoting the last super-admin +// (which would lock the system out — no remaining super-admin = no +// one can promote anyone else). +func (m *UserManager) CountAdmins() (int64, error) { + var results int64 + if err := m.DB.Model(&AdminUser{}).Where("admin = ?", true).Count(&results).Error; err != nil { + return 0, fmt.Errorf("count admins: %w", err) + } + return results, nil +} + // ChangeAdmin to modify the admin setting for a user func (m *UserManager) ChangeAdmin(username string, admin bool) error { user, err := m.Get(username) @@ -327,17 +374,42 @@ func (m *UserManager) UpdateToken(username, token string, exp time.Time) error { return fmt.Errorf("error getting user %w", err) } if token != user.APIToken { - if err := m.DB.Model(&user).Updates( - AdminUser{ - APIToken: token, - TokenExpire: exp, - }).Error; err != nil { + // Rotation also clears CSRFToken so the SPA's old non-HttpOnly + // CSRF cookie value stops matching the server-side binding — + // stops a stale CSRFToken from outliving the JWT it was minted + // alongside. The SPA must re-login (which writes a fresh + // CSRFToken via UpdateMetadata) before mutations work again. 
+ // + if err := m.DB.Model(&user).Updates(map[string]interface{}{ + "api_token": token, + "token_expire": exp, + "csrf_token": "", + }).Error; err != nil { return fmt.Errorf("update %w", err) } } return nil } +// ClearToken empties the user's APIToken and CSRFToken so any existing +// JWT + CSRF cookie pair for them stops validating. Used by DELETE +// /api/v1/users/{username}/token. We use a map-update so the empty +// strings actually land (GORM's struct-Updates skips zero-value fields). +func (m *UserManager) ClearToken(username string) error { + user, err := m.Get(username) + if err != nil { + return fmt.Errorf("error getting user %w", err) + } + if err := m.DB.Model(&user).Updates(map[string]interface{}{ + "api_token": "", + "token_expire": time.Time{}, + "csrf_token": "", + }).Error; err != nil { + return fmt.Errorf("update %w", err) + } + return nil +} + // ChangeEmail for user by username func (m *UserManager) ChangeEmail(username, email string) error { user, err := m.Get(username) diff --git a/pkg/users/users_test.go b/pkg/users/users_test.go index 977a4a6b..771320ff 100644 --- a/pkg/users/users_test.go +++ b/pkg/users/users_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/DATA-DOG/go-sqlmock" + "github.com/golang-jwt/jwt/v4" "github.com/jmpsec/osctrl/pkg/config" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -16,7 +17,7 @@ import ( func setupTestManager(t *testing.T) (*UserManager, sqlmock.Sqlmock) { conf := config.YAMLConfigurationJWT{ - JWTSecret: "test", + JWTSecret: "test-secret-must-be-at-least-32-bytes-long", HoursToExpire: 1, } mockDB, mock, err := sqlmock.New() @@ -72,14 +73,14 @@ func TestHashTextWithSalt(t *testing.T) { manager, _ := setupTestManager(t) hashed, err := manager.HashTextWithSalt("testText") assert.NoError(t, err) - assert.Equal(t, hashed[0:7], "$2a$10$") + assert.Equal(t, hashed[0:7], "$2a$12$") } func TestHashPasswordWithSalt(t *testing.T) { manager, _ := setupTestManager(t) hashed, err := 
manager.HashPasswordWithSalt("testPassword") assert.NoError(t, err) - assert.Equal(t, hashed[0:7], "$2a$10$") + assert.Equal(t, hashed[0:7], "$2a$12$") } func TestCheckLoginCredentials(t *testing.T) { @@ -105,7 +106,7 @@ func TestCheckLoginCredentials(t *testing.T) { func TestCreateCheckToken(t *testing.T) { manager, _ := setupTestManager(t) conf := config.YAMLConfigurationJWT{ - JWTSecret: "test", + JWTSecret: "test-secret-must-be-at-least-32-bytes-long", } token, tt, err := manager.CreateToken("testUsername", "issuer", 0) assert.NoError(t, err) @@ -117,6 +118,20 @@ func TestCreateCheckToken(t *testing.T) { assert.Equal(t, "testUsername", claims.Username) } +// TestCheckTokenRejectsNoneAlg locks in the key-func's alg-pinning behaviour: +// even if a forged token bypasses the library's own none-mitigation, our +// explicit `*jwt.SigningMethodHMAC` type-assertion refuses it. +func TestCheckTokenRejectsNoneAlg(t *testing.T) { + manager, _ := setupTestManager(t) + // Hand-build a token signed with alg:none. golang-jwt requires + // jwt.UnsafeAllowNoneSignatureType as the key for SignedString to succeed. + tok := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{"username": "attacker"}) + signed, err := tok.SignedString(jwt.UnsafeAllowNoneSignatureType) + assert.NoError(t, err) + _, valid := manager.CheckToken("test-secret-must-be-at-least-32-bytes-long", signed) + assert.False(t, valid, "alg:none tokens must be rejected by the key-func") +} + func TestGetUser(t *testing.T) { manager, mock := setupTestManager(t) mock.ExpectQuery( @@ -387,9 +402,12 @@ func TestUpdateToken(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) mock.ExpectBegin() + // UpdateToken now also clears csrf_token alongside api_token / + // token_expire so a stale CSRF cookie can't outlive its session. 
+ // mock.ExpectExec( - regexp.QuoteMeta(`UPDATE "admin_users" SET "updated_at"=$1,"api_token"=$2,"token_expire"=$3 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $4`)). - WithArgs(sqlmock.AnyArg(), "testToken", tt, 1). + regexp.QuoteMeta(`UPDATE "admin_users" SET "api_token"=$1,"csrf_token"=$2,"token_expire"=$3,"updated_at"=$4 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $5`)). + WithArgs("testToken", "", tt, sqlmock.AnyArg(), 1). WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() @@ -430,3 +448,45 @@ func TestGetAllUsers(t *testing.T) { assert.Equal(t, 1, len(users)) } + +// TestUpdateTokenClearsCSRF locks the contract that rotating APIToken +// also clears CSRFToken so a stale CSRF cookie can't outlive its +// session. +func TestUpdateTokenClearsCSRF(t *testing.T) { + manager, mock := setupTestManager(t) + tt := time.Now() + mock.ExpectQuery( + regexp.QuoteMeta(`SELECT * FROM "admin_users" WHERE username = $1 AND "admin_users"."deleted_at" IS NULL ORDER BY "admin_users"."id" LIMIT $2`)). + WithArgs("alice", 1). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + mock.ExpectBegin() + mock.ExpectExec( + regexp.QuoteMeta(`UPDATE "admin_users" SET "api_token"=$1,"csrf_token"=$2,"token_expire"=$3,"updated_at"=$4 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $5`)). + WithArgs("freshtoken", "", tt, sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := manager.UpdateToken("alice", "freshtoken", tt) + assert.NoError(t, err) +} + +// TestClearTokenAlsoClearsCSRF locks the contract that DELETE +// /users/{u}/token wipes both api_token and csrf_token. +func TestClearTokenAlsoClearsCSRF(t *testing.T) { + manager, mock := setupTestManager(t) + mock.ExpectQuery( + regexp.QuoteMeta(`SELECT * FROM "admin_users" WHERE username = $1 AND "admin_users"."deleted_at" IS NULL ORDER BY "admin_users"."id" LIMIT $2`)). + WithArgs("bob", 1). 
+ WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + mock.ExpectBegin() + mock.ExpectExec( + regexp.QuoteMeta(`UPDATE "admin_users" SET "api_token"=$1,"csrf_token"=$2,"token_expire"=$3,"updated_at"=$4 WHERE "admin_users"."deleted_at" IS NULL AND "id" = $5`)). + WithArgs("", "", sqlmock.AnyArg(), sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err := manager.ClearToken("bob") + assert.NoError(t, err) +} diff --git a/pkg/utils/http-utils.go b/pkg/utils/http-utils.go index 41cf8136..1636b5b2 100644 --- a/pkg/utils/http-utils.go +++ b/pkg/utils/http-utils.go @@ -6,10 +6,13 @@ import ( "encoding/json" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" "strconv" + "strings" + "sync" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -83,6 +86,14 @@ const Authorization string = "Authorization" // OsctrlUserAgent for customized User-Agent const OsctrlUserAgent string = "osctrl-http-client/1.1" +// AcceptsJSON reports whether the request's Accept header signals JSON. +// Used by the auth middleware to choose between 401 JSON (for SPA/XHR +// clients) and 302 redirect (for browser navigation). +func AcceptsJSON(r *http.Request) bool { + accept := r.Header.Get("Accept") + return strings.Contains(strings.ToLower(accept), "application/json") +} + // SendRequest - Helper function to send HTTP requests func SendRequest(reqType, reqURL string, params io.Reader, headers map[string]string) (int, []byte, error) { u, err := url.Parse(reqURL) @@ -146,17 +157,122 @@ func DebugHTTPDump(l *zerolog.Logger, r *http.Request, showBody bool) { l.Log().Msg(DebugHTTP(r, showBody)) } -// GetIP - Helper to get the IP address from a HTTP request +// trustedProxies is the global set of CIDRs whose X-Real-IP / +// X-Forwarded-For headers GetIP is allowed to honor. 
When empty (the +// safe default), GetIP returns the connection's RemoteAddr IP verbatim +// and ignores any forwarding headers — preventing an anonymous internet +// attacker from rotating headers to defeat rate-limits or poison the +// audit log. Operators wire trusted proxies at startup via +// SetTrustedProxies; once set, GetIP only consults forwarding headers +// when the connecting peer falls inside one of the configured CIDRs. +var ( + trustedProxiesMu sync.RWMutex + trustedProxies []*net.IPNet +) + +// SetTrustedProxies configures the CIDR allowlist for forwarding-header +// trust. Pass an empty slice (or call with no args) to revert to the +// safe-by-default "ignore forwarding headers" posture. Each CIDR string +// must parse via net.ParseCIDR; invalid entries are logged and skipped. +func SetTrustedProxies(cidrs []string) { + parsed := make([]*net.IPNet, 0, len(cidrs)) + for _, c := range cidrs { + c = strings.TrimSpace(c) + if c == "" { + continue + } + _, n, err := net.ParseCIDR(c) + if err != nil { + log.Warn().Str("cidr", c).Err(err).Msg("trusted-proxies: invalid CIDR, skipping") + continue + } + parsed = append(parsed, n) + } + trustedProxiesMu.Lock() + trustedProxies = parsed + trustedProxiesMu.Unlock() +} + +// isFromTrustedProxy reports whether the connecting peer (host portion +// of r.RemoteAddr) sits inside any configured trusted-proxy CIDR. +func isFromTrustedProxy(r *http.Request) bool { + trustedProxiesMu.RLock() + tps := trustedProxies + trustedProxiesMu.RUnlock() + if len(tps) == 0 { + return false + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + for _, n := range tps { + if n.Contains(ip) { + return true + } + } + return false +} + +// remoteIP returns the connecting peer's IP (no port). Falls back to +// RemoteAddr-as-is when SplitHostPort fails (rare; some net/http test +// machinery omits the port). 
+func remoteIP(r *http.Request) string { + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +// GetIP returns the client IP for r. When trusted-proxies are configured +// AND r.RemoteAddr's IP is inside one of them, the right-most untrusted +// hop from X-Forwarded-For (or X-Real-IP) is used (per RFC 7239 §5.2 the +// right-most-untrusted is the IP the trusted edge actually saw connect). +// Otherwise the forwarding headers are ignored and the connection's +// RemoteAddr IP is returned. func GetIP(r *http.Request) string { - realIP := r.Header.Get(XRealIP) - if realIP != "" { - return realIP + if !isFromTrustedProxy(r) { + // Default safe path: never trust forwarding headers. + return remoteIP(r) + } + // Trusted-proxy path. Prefer X-Forwarded-For (a comma-list of hops: + // `client, proxy1, proxy2`). Walk right-to-left and return the + // first IP that's NOT itself inside a trusted-proxy CIDR. + if xff := r.Header.Get(XForwardedFor); xff != "" { + hops := strings.Split(xff, ",") + trustedProxiesMu.RLock() + tps := trustedProxies + trustedProxiesMu.RUnlock() + for i := len(hops) - 1; i >= 0; i-- { + hop := strings.TrimSpace(hops[i]) + ip := net.ParseIP(hop) + if ip == nil { + continue + } + isProxy := false + for _, n := range tps { + if n.Contains(ip) { + isProxy = true + break + } + } + if !isProxy { + return hop + } + } } - forwarded := r.Header.Get(XForwardedFor) - if forwarded != "" { - return forwarded + // Fall back to X-Real-IP (set by single-hop edges like nginx with + // `proxy_set_header X-Real-IP $remote_addr;`). + if rip := strings.TrimSpace(r.Header.Get(XRealIP)); rip != "" { + return rip } - return r.RemoteAddr + // Last resort: the trusted proxy's own address. 
+ return remoteIP(r) } // HTTPResponse - Helper to send HTTP response diff --git a/pkg/utils/http-utils_test.go b/pkg/utils/http-utils_test.go index e7013c0e..e844077c 100644 --- a/pkg/utils/http-utils_test.go +++ b/pkg/utils/http-utils_test.go @@ -85,21 +85,31 @@ func TestSendRequest(t *testing.T) { } func TestGetIP(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + // All three sub-tests run with a trusted-proxy configuration that + // covers the test RemoteAddr (127.0.0.0/8 for httptest defaults + // and the test addresses below). Without trust configured, GetIP + // ignores forwarding headers — that contract is asserted in + // TestGetIPIgnoresHeadersByDefault. + SetTrustedProxies([]string{"127.0.0.0/8"}) t.Run("get ip X-Real-IP header", func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://whatever/server/path", nil) + req.RemoteAddr = "127.0.0.1:1234" // inside trusted CIDR req.Header.Set(XRealIP, "1.2.3.4") ip := GetIP(req) assert.Equal(t, "1.2.3.4", ip) }) t.Run("get ip X-Forwarder-For header", func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://whatever/server/path", nil) + req.RemoteAddr = "127.0.0.1:1234" req.Header.Set(XForwardedFor, "1.2.3.4") ip := GetIP(req) assert.Equal(t, "1.2.3.4", ip) }) t.Run("get ip RemoteAddr", func(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "https://whatever/server/path", nil) - req.Header.Set(XForwardedFor, "") + // No RemoteAddr set and no headers — GetIP falls back to the + // empty value the request was built with. ip := GetIP(req) assert.Equal(t, "", ip) }) @@ -132,3 +142,70 @@ func TestHTTPDownload(t *testing.T) { assert.Equal(t, "123", rr.Header().Get(ContentLength)) }) } + +// TestGetIPIgnoresHeadersByDefault — out-of-the-box GetIP MUST NOT +// consult X-Real-IP / X-Forwarded-For. 
+func TestGetIPIgnoresHeadersByDefault(t *testing.T) { + SetTrustedProxies(nil) // reset + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "203.0.113.5:12345" + req.Header.Set("X-Real-IP", "99.99.99.99") + req.Header.Set("X-Forwarded-For", "1.2.3.4, 5.6.7.8") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("default GetIP: got %q, want %q (forwarding headers must be ignored)", got, "203.0.113.5") + } +} + +// TestGetIPHonorsTrustedProxy — when the connecting peer is inside a +// trusted-proxy CIDR, the right-most untrusted hop from X-Forwarded-For +// becomes the result. +func TestGetIPHonorsTrustedProxy(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"10.0.0.0/8"}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.5:12345" // trusted edge + // `client, edge1, edge2` — edge1/edge2 are inside the trusted CIDR, + // so the right-most-untrusted is "203.0.113.5". + req.Header.Set("X-Forwarded-For", "203.0.113.5, 10.0.0.1, 10.0.0.5") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("trusted XFF: got %q, want %q", got, "203.0.113.5") + } +} + +// TestGetIPUntrustedPeerIgnoresHeaders — even with trusted proxies set, +// a request coming from OUTSIDE the trusted CIDRs must ignore headers. +func TestGetIPUntrustedPeerIgnoresHeaders(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"10.0.0.0/8"}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "203.0.113.5:12345" // NOT in trusted CIDR + req.Header.Set("X-Forwarded-For", "1.2.3.4") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("untrusted peer with header: got %q, want %q", got, "203.0.113.5") + } +} + +// TestGetIPTrustedProxyIPv6 — verify IPv6 trusted-proxy match. 
+func TestGetIPTrustedProxyIPv6(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"fd00::/8"}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "[fd00::1]:443" + req.Header.Set("X-Forwarded-For", "2001:db8::1") + if got := GetIP(req); got != "2001:db8::1" { + t.Errorf("trusted IPv6 XFF: got %q, want %q", got, "2001:db8::1") + } +} + +// TestSetTrustedProxiesIgnoresInvalid — bad CIDRs are dropped silently +// rather than panicking; the remaining good ones still apply. +func TestSetTrustedProxiesIgnoresInvalid(t *testing.T) { + t.Cleanup(func() { SetTrustedProxies(nil) }) + SetTrustedProxies([]string{"not-a-cidr", "10.0.0.0/8", "", " "}) + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "10.0.0.1:443" + req.Header.Set("X-Real-IP", "203.0.113.5") + if got := GetIP(req); got != "203.0.113.5" { + t.Errorf("partial CIDR set: got %q, want %q", got, "203.0.113.5") + } +} From b8f83ff673dd26b88e3f3bf5ad463c9de2791e4d Mon Sep 17 00:00:00 2001 From: alvarofraguas Date: Thu, 14 May 2026 19:17:38 +0200 Subject: [PATCH 2/2] osctrl-api: API extensions for a React admin frontend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Round 2 of 3 (round 1: security; round 3: frontend). Adds the API surface the SPA needs to fully replace the legacy admin templates. No existing routes are removed or repurposed — every new endpoint is additive. The new shapes are SPA-canonical (paginated envelope, projections, typed PATCH bodies). 
== New endpoints == Stats / dashboard: GET /api/v1/stats cross-env summary KPIs GET /api/v1/stats/osquery-versions fleet agent versions GET /api/v1/stats/activity/{env} env-scoped audit-log activity heatmap GET /api/v1/stats/activity/node/{env}/{uuid} per-node activity heatmap GET /api/v1/stats/activity/node-batch/{env} per-node heatmap, up to 100 uuids Logs (live SPA log viewer): GET /api/v1/logs/{type}/{env}/{uuid} paginated, since-aware Saved queries (full CRUD): GET /api/v1/saved-queries/{env} POST /api/v1/saved-queries/{env} PATCH /api/v1/saved-queries/{env}/{name} DELETE /api/v1/saved-queries/{env}/{name} User profile + token + permissions: GET /api/v1/users/me PATCH /api/v1/users/me POST /api/v1/users/me/password POST /api/v1/users/{username}/permissions POST /api/v1/users/{username}/token/refresh DELETE /api/v1/users/{username}/token Environment CRUD + config PATCHes: POST /api/v1/environments PATCH /api/v1/environments/{env} DELETE /api/v1/environments/{env} GET /api/v1/environments/{env}/config PATCH /api/v1/environments/{env}/config PATCH /api/v1/environments/{env}/intervals PATCH /api/v1/environments/{env}/expiration Settings PATCH: PATCH /api/v1/settings/{service}/{name} Audit log filters + pagination: GET /api/v1/audit-logs?service=&username=&type=&envUuid=&since=&until=&page=&pageSize= Login envs (pre-auth env list): GET /api/v1/login/environments pre-auth-safe UUID+name only Sample libraries (operator starter packs): GET /api/v1/queries/samples GET /api/v1/carves/samples GET /api/v1/osquery/tables == Pagination + sort + search == Every list endpoint accepts ?page=&page_size= (default 50, max 500) and returns the envelope: { "items": [...], "page": N, "page_size": N, "total_items": N, "total_pages": N } Sortable fields use a per-resource SortableColumns allowlist enforced at the package layer (pkg/nodes, pkg/queries, pkg/carves). Unknown sort keys fall back to the resource's default order without 400ing. 
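The allowlist-with-fallback sort behavior described above can be sketched as follows — the map contents, function name, and column names here are illustrative assumptions, not the shipped helpers in pkg/nodes / pkg/queries / pkg/carves:

```go
package main

import "fmt"

// sortableColumns is a hypothetical per-resource allowlist mapping
// caller-facing sort keys to real column names; unknown keys are
// simply absent from the map.
var sortableColumns = map[string]string{
	"hostname":  "hostname",
	"last_seen": "last_status",
	"platform":  "platform",
}

// defaultOrder is the resource's fallback ORDER BY clause.
const defaultOrder = "updated_at DESC"

// orderClause resolves a caller-supplied sort key against the
// allowlist. Unknown keys fall back to the default order instead of
// producing a 400 — and because only allowlisted column names ever
// reach the SQL, ?sort= cannot carry an injection payload.
func orderClause(sortKey string, desc bool) string {
	col, ok := sortableColumns[sortKey]
	if !ok {
		return defaultOrder
	}
	if desc {
		return col + " DESC"
	}
	return col + " ASC"
}

func main() {
	fmt.Println(orderClause("hostname", false))
	fmt.Println(orderClause("'; DROP TABLE nodes; --", true))
}
```

Centralizing the allowlist at the package layer (rather than per-handler) is what lets every list endpoint share the same safe fallback.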
Search is ?q= free-text against a per-resource field set (case-insensitive LIKE). Wildcards are escaped server-side.

== New package: pkg/dbutil ==

Dialect-aware SQL bucket-expression helper (postgres / mysql / sqlite) used by the activity heatmap endpoints. Each category (status logs / result logs / distributed queries / carves) issues a single SQL GROUP BY rather than fetching every timestamp — at 50k+ nodes the table-page heatmap query is bounded by an index scan instead of by per-row traffic.

== Package-layer additions ==

pkg/nodes: GetByEnvPaged, NodeView projection, SortableColumns, platform-bucket helpers, GetOsqueryVersionCounts.
pkg/queries: GetByEnvTargetPaged, GetSaved* CRUD, SortableColumns, sample-template loader, GetNodeQueryBucketed.
pkg/carves: GetByEnvPaged, sample-template loader, GetNodeCarveBucketed.
pkg/environments: Create / Update / Delete, UpdateConfig / UpdateIntervals / UpdateExpiration helpers.
pkg/auditlog: GetPaged with PageFilter; FailedLogin / FailedEnroll hooks; GetEnvActivityBucketed for the heatmap.
pkg/logging: GetNodeLogs with ?q= search filter, GetNode{Status,Result}Bucketed for the heatmap.
pkg/osquery: LoadTables (osquery schema for the SPA query editor).
pkg/types: NodeView, paginated response envelopes, EnvCreate / EnvUpdate / EnvConfig* request types, SettingPatchRequest, SavedQueryView, AdminUserView.

Verified: go build ./... clean, go vet ./... clean, go test ./... — all packages pass. End-to-end tested against a Kali Docker deployment.

== What this depends on ==

This PR is stacked on the security-hardening PR (auth bedrock, env secret containment, TLS-side rate-limit). When that PR is merged upstream, this branch will be re-targeted at the new main HEAD.

== What this enables ==

A separate round-3 PR will land the React admin SPA under a new `frontend/` directory at the repo root. The SPA consumes only the endpoints in this PR — no admin-template surface is touched.
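The dialect-aware bucket expression described for pkg/dbutil can be sketched as below. This is an assumption-laden illustration: `bucketExpr` and its signature are hypothetical, and the real helper's name and exact SQL may differ; only the per-dialect epoch-truncation technique is being shown.

```go
package main

import "fmt"

// bucketExpr returns a SQL expression that truncates a timestamp column to
// fixed-size buckets of secs seconds, per dialect. Illustrative sketch only;
// the real pkg/dbutil helper may use different names and expressions.
func bucketExpr(dialect, col string, secs int) string {
	switch dialect {
	case "postgres":
		// floor(epoch / N) * N, converted back to a timestamp
		return fmt.Sprintf("to_timestamp(floor(extract(epoch from %s) / %d) * %d)", col, secs, secs)
	case "mysql":
		return fmt.Sprintf("from_unixtime(floor(unix_timestamp(%s) / %d) * %d)", col, secs, secs)
	default: // sqlite
		return fmt.Sprintf("datetime((strftime('%%s', %s) / %d) * %d, 'unixepoch')", col, secs, secs)
	}
}

func main() {
	// One GROUP BY per category instead of fetching every timestamp row:
	expr := bucketExpr("postgres", "created_at", 3600)
	fmt.Printf("SELECT %s AS bucket, count(*) FROM status_logs GROUP BY bucket ORDER BY bucket\n", expr)
}
```

Because the bucket expression is computed server-side per dialect, the heatmap endpoint ships a single aggregate query to the database and the result-set size is bounded by the number of buckets, not the number of log rows.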
--- cmd/api/handlers/audit.go | 140 ++++++- cmd/api/handlers/carves.go | 396 +++++++++++++++---- cmd/api/handlers/environments.go | 11 +- cmd/api/handlers/environments_crud.go | 506 ++++++++++++++++++++++++ cmd/api/handlers/environments_test.go | 19 +- cmd/api/handlers/handlers.go | 9 + cmd/api/handlers/login_envs.go | 48 +++ cmd/api/handlers/logs.go | 124 ++++++ cmd/api/handlers/nodes.go | 177 +++++++-- cmd/api/handlers/queries.go | 258 ++++++++++-- cmd/api/handlers/samples.go | 38 ++ cmd/api/handlers/saved_queries.go | 257 ++++++++++++ cmd/api/handlers/settings.go | 10 +- cmd/api/handlers/settings_patch.go | 111 ++++++ cmd/api/handlers/stats.go | 539 ++++++++++++++++++++++++++ cmd/api/handlers/stats_test.go | 94 +++++ cmd/api/handlers/tags.go | 76 ++-- cmd/api/handlers/users_profile.go | 293 ++++++++++++++ cmd/api/main.go | 142 ++++++- pkg/auditlog/audit.go | 164 ++++++++ pkg/carves/carves.go | 26 ++ pkg/carves/samples.go | 236 +++++++++++ pkg/dbutil/buckets.go | 78 ++++ pkg/environments/environments.go | 83 ++-- pkg/logging/db.go | 228 +++++++++++ pkg/nodes/models.go | 100 ++--- pkg/nodes/nodes.go | 215 +++++++--- pkg/nodes/nodes_test.go | 77 ++++ pkg/nodes/utils.go | 128 ++++++ pkg/osquery/tables.go | 34 ++ pkg/queries/queries.go | 168 +++++++- pkg/queries/queries_test.go | 27 +- pkg/queries/samples.go | 275 +++++++++++++ pkg/queries/saved.go | 162 +++++++- pkg/queries/saved_test.go | 125 ++++++ pkg/tags/tags.go | 32 +- pkg/types/node_view.go | 199 ++++++++++ pkg/types/types.go | 282 +++++++++++++- 38 files changed, 5496 insertions(+), 391 deletions(-) create mode 100644 cmd/api/handlers/environments_crud.go create mode 100644 cmd/api/handlers/login_envs.go create mode 100644 cmd/api/handlers/logs.go create mode 100644 cmd/api/handlers/samples.go create mode 100644 cmd/api/handlers/saved_queries.go create mode 100644 cmd/api/handlers/settings_patch.go create mode 100644 cmd/api/handlers/stats.go create mode 100644 cmd/api/handlers/stats_test.go create mode 
100644 cmd/api/handlers/users_profile.go create mode 100644 pkg/carves/samples.go create mode 100644 pkg/dbutil/buckets.go create mode 100644 pkg/nodes/nodes_test.go create mode 100644 pkg/osquery/tables.go create mode 100644 pkg/queries/samples.go create mode 100644 pkg/queries/saved_test.go create mode 100644 pkg/types/node_view.go diff --git a/cmd/api/handlers/audit.go b/cmd/api/handlers/audit.go index 0233811e..05620c6f 100644 --- a/cmd/api/handlers/audit.go +++ b/cmd/api/handlers/audit.go @@ -3,34 +3,156 @@ package handlers import ( "fmt" "net/http" + "strconv" "strings" + "time" "github.com/jmpsec/osctrl/pkg/auditlog" + "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/users" "github.com/jmpsec/osctrl/pkg/utils" "github.com/rs/zerolog/log" ) -// AuditLogsHandler - GET Handler for all audit logs +// AuditLogsHandler - GET /api/v1/audit-logs +// +// Query params: +// +// ?service=... exact match on service name +// ?username=... case-insensitive partial match on username +// ?type=... log type integer (1..10), see pkg/auditlog.LogType* +// ?env_uuid=... filter to one environment (resolved to internal ID) +// ?since=RFC3339 created_at >= since +// ?until=RFC3339 created_at <= until +// ?page=N 1-indexed page; default 1 +// ?page_size=N default 50, max 500 +// +// Returns the SPA-canonical paginated envelope. The handler audit-logs the +// visit on success. 
func (h *HandlersApi) AuditLogsHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get audit logs - auditLogs, err := h.AuditLog.GetAll() + + q := r.URL.Query() + filter := auditlog.PageFilter{ + Service: strings.TrimSpace(q.Get("service")), + Username: strings.TrimSpace(q.Get("username")), + } + if v := q.Get("type"); v != "" { + n, err := strconv.ParseUint(v, 10, 32) + if err != nil { + apiErrorResponse(w, "type must be an integer", http.StatusBadRequest, err) + return + } + if _, ok := auditlog.LogTypes[uint(n)]; !ok { + apiErrorResponse(w, "type is not a known log_type", http.StatusBadRequest, nil) + return + } + filter.LogType = uint(n) + } + if v := q.Get("env_uuid"); v != "" { + env, err := h.Envs.GetByUUID(v) + if err != nil { + apiErrorResponse(w, "env_uuid not found", http.StatusBadRequest, err) + return + } + filter.EnvID = env.ID + } + if v := q.Get("since"); v != "" { + t, err := time.Parse(time.RFC3339, v) + if err != nil { + apiErrorResponse(w, "since must be RFC3339", http.StatusBadRequest, err) + return + } + filter.Since = t + } + if v := q.Get("until"); v != "" { + t, err := time.Parse(time.RFC3339, v) + if err != nil { + apiErrorResponse(w, "until must be RFC3339", http.StatusBadRequest, err) + return + } + filter.Until = t + } + if v := q.Get("page"); v != "" { + n, err := strconv.Atoi(v) + if err != nil || n < 1 { + apiErrorResponse(w, "page must be a positive integer", http.StatusBadRequest, err) + return + } + filter.Page = n + } else { + filter.Page = 1 + } + if v := q.Get("page_size"); v != "" { + n, err := 
strconv.Atoi(v) + if err != nil || n < 1 { + apiErrorResponse(w, "page_size must be a positive integer", http.StatusBadRequest, err) + return + } + filter.PageSize = n + } + if filter.PageSize == 0 { + filter.PageSize = 50 + } + // Mirror the package-layer clamp at the handler so the response + // envelope echoes the actual effective value and the doc-comment + // "max 500" remains honest if the package layer's bound ever + // shifts. + if filter.PageSize > 500 { + filter.PageSize = 500 + } + + rows, total, err := h.AuditLog.GetPaged(filter) if err != nil { - log.Err(err).Msg("error getting audit logs") + apiErrorResponse(w, "error getting audit logs", http.StatusInternalServerError, err) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned %d audit log entries", len(auditLogs)) + + // Resolve EnvironmentID → UUID with a single map lookup so the SPA can + // render env names directly. Empty UUID == no env / system action. + envMap, _ := h.Envs.GetMapByID() + + items := make([]types.AuditLogView, 0, len(rows)) + for _, r := range rows { + view := types.AuditLogView{ + ID: r.ID, + CreatedAt: r.CreatedAt, + Service: r.Service, + Username: r.Username, + Line: r.Line, + LogType: r.LogType, + Severity: r.Severity, + SourceIP: r.SourceIP, + EnvironmentID: r.EnvironmentID, + } + if r.EnvironmentID > 0 { + if e, ok := envMap[r.EnvironmentID]; ok { + view.EnvUUID = e.UUID + } + } + items = append(items, view) + } + + totalPages := 0 + if total > 0 { + totalPages = int((total + int64(filter.PageSize) - 1) / int64(filter.PageSize)) + } + resp := types.AuditLogsPagedResponse{ + Items: items, + Page: filter.Page, + PageSize: filter.PageSize, + TotalItems: total, + TotalPages: totalPages, + } + h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], auditlog.NoEnvironment) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, auditLogs) + log.Debug().Msgf("Returned %d audit log entries (page=%d, size=%d, total=%d)", len(items), 
filter.Page, filter.PageSize, total) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp) } diff --git a/cmd/api/handlers/carves.go b/cmd/api/handlers/carves.go index 8d9889e4..661c5014 100644 --- a/cmd/api/handlers/carves.go +++ b/cmd/api/handlers/carves.go @@ -2,12 +2,17 @@ package handlers import ( "encoding/json" + "errors" "fmt" + "io" "net/http" + "os" + "strconv" "strings" "time" "github.com/jmpsec/osctrl/pkg/carves" + "github.com/jmpsec/osctrl/pkg/config" "github.com/jmpsec/osctrl/pkg/handlers" "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" @@ -15,178 +20,234 @@ import ( "github.com/jmpsec/osctrl/pkg/users" "github.com/jmpsec/osctrl/pkg/utils" "github.com/rs/zerolog/log" + "gorm.io/gorm" ) -// GET Handler to return a single carve in JSON +// carveFileView projects a CarvedFile row into the SPA-canonical envelope. +// time.Time stays as time.Time so JSON-encoded output is RFC3339. +func carveFileView(c carves.CarvedFile) types.CarveFileView { + return types.CarveFileView{ + CarveID: c.CarveID, + SessionID: c.SessionID, + UUID: c.UUID, + Path: c.Path, + Status: c.Status, + CarveSize: c.CarveSize, + BlockSize: c.BlockSize, + TotalBlocks: c.TotalBlocks, + CompletedBlocks: c.CompletedBlocks, + Archived: c.Archived, + CreatedAt: c.CreatedAt, + CompletedAt: c.CompletedAt, + } +} + +// CarveShowHandler - GET /api/v1/carves/{env}/{name} +// +// Returns the carve query metadata plus the array of per-node CarvedFile rows +// produced by the carve. Returns 404 when the carve query name does not exist +// in the environment. 
func (h *HandlersApi) CarveShowHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract name name := r.PathValue("name") if name == "" { - apiErrorResponse(w, "error getting name", http.StatusInternalServerError, nil) + apiErrorResponse(w, "error getting name", http.StatusBadRequest, nil) return } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.CarveLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get carve by name - carve, err := h.Carves.GetByQuery(name, env.ID) + + // Look up the carve query (DistributedQuery row with type=carve). + q, err := h.Queries.Get(name, env.ID) if err != nil { - if err.Error() == "record not found" { + if errors.Is(err, gorm.ErrRecordNotFound) { apiErrorResponse(w, "carve not found", http.StatusNotFound, err) - } else { - apiErrorResponse(w, "error getting carve", http.StatusInternalServerError, err) + return } + apiErrorResponse(w, "error getting carve", http.StatusInternalServerError, err) + return + } + if q.Type != queries.CarveQueryType { + apiErrorResponse(w, "carve not found", http.StatusNotFound, nil) + return + } + + // Look up the carved files (one per node that completed the carve). 
+ files, err := h.Carves.GetByQuery(name, env.ID) + if err != nil { + apiErrorResponse(w, "error getting carve files", http.StatusInternalServerError, err) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned carve %s", name) + views := make([]types.CarveFileView, 0, len(files)) + for _, f := range files { + views = append(views, carveFileView(f)) + } + + resp := types.CarveDetailResponse{Query: q, Files: views} + log.Debug().Msgf("Returned carve %s (%d files)", name, len(views)) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, carve) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp) } -// GET Handler to return carve queries in JSON by target and environment +// CarveQueriesHandler - GET /api/v1/carves/{env}/queries/{target} +// +// Returns carve queries by target. Retained from the legacy contract; the +// canonical list endpoint is now CarveListHandler at /api/v1/carves/{env}. 
func (h *HandlersApi) CarveQueriesHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.CarveLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Extract target targetVar := r.PathValue("target") if targetVar == "" { apiErrorResponse(w, "error with target", http.StatusBadRequest, nil) return } - // Verify target if !QueryTargets[targetVar] { apiErrorResponse(w, "invalid target", http.StatusBadRequest, nil) return } - // Get carves - carves, err := h.Queries.GetCarves(targetVar, env.ID) + carvesList, err := h.Queries.GetCarves(targetVar, env.ID) if err != nil { apiErrorResponse(w, "error getting carve queries", http.StatusInternalServerError, err) return } - if len(carves) == 0 { - apiErrorResponse(w, "no carve queries", http.StatusNotFound, nil) - return - } - // Serialize and serve JSON - log.Debug().Msgf("Returned %d carves", len(carves)) + log.Debug().Msgf("Returned %d carves", len(carvesList)) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, carves) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, carvesList) } -// GET Handler to return carves in JSON by environment +// CarveListHandler - GET /api/v1/carves/{env} +// +// 
Paginated, sorted, searchable list of carve queries (DistributedQuery rows +// with type=carve). Query params: page, page_size, q, sort, dir, target. +// Empty result → HTTP 200 with items: []. func (h *HandlersApi) CarveListHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.CarveLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get carves - carves, err := h.Carves.GetByEnv(env.ID) + + q := r.URL.Query() + page, _ := strconv.Atoi(q.Get("page")) + pageSize, _ := strconv.Atoi(q.Get("page_size")) + search := q.Get("q") + sortCol := q.Get("sort") + desc := strings.ToLower(q.Get("dir")) != "asc" + target := q.Get("target") + if target == "" { + target = queries.TargetAll + } + if !QueryTargets[target] { + apiErrorResponse(w, "invalid target", http.StatusBadRequest, nil) + return + } + + if pageSize <= 0 { + pageSize = 50 + } + if pageSize > 500 { + pageSize = 500 + } + if page <= 0 { + page = 1 + } + + result, err := h.Queries.GetByEnvTargetPaged(env.ID, target, queries.CarveQueryType, search, page, pageSize, sortCol, desc) if err != nil { apiErrorResponse(w, "error getting carves", http.StatusInternalServerError, err) return } - if len(carves) == 0 { - apiErrorResponse(w, "no carves", http.StatusNotFound, nil) - return + items := result.Items + if 
items == nil { + items = []queries.DistributedQuery{} + } + var totalPages int + if result.TotalItems > 0 { + totalPages = int((result.TotalItems + int64(pageSize) - 1) / int64(pageSize)) } - // Serialize and serve JSON - log.Debug().Msgf("Returned %d carves", len(carves)) + resp := types.CarvesPagedResponse{ + Items: items, + Page: page, + PageSize: pageSize, + TotalItems: result.TotalItems, + TotalPages: totalPages, + } + log.Debug().Msgf("Returned %d carves (page %d of %d)", len(items), page, totalPages) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, carves) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp) } -// POST Handler to run a carve +// CarvesRunHandler - POST /api/v1/carves/{env} func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.CarveLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } var c types.ApiDistributedQueryRequest - // Parse request JSON body if err := json.NewDecoder(r.Body).Decode(&c); err != nil { - apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) + apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err) return } - // 
Path can not be empty if c.Path == "" { - apiErrorResponse(w, "path can not be empty", http.StatusInternalServerError, nil) + apiErrorResponse(w, "path can not be empty", http.StatusBadRequest, nil) return } // Validate the path before it's spliced into the osquery SQL via @@ -209,7 +270,6 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { if c.ExpHours == 0 { expTime = time.Time{} } - // Prepare and create new carve newQuery := queries.DistributedQuery{ Query: carves.GenCarveQuery(c.Path, false), Name: carves.GenCarveName(), @@ -224,7 +284,6 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) return } - // Prepare data for the handler code data := handlers.ProcessingQuery{ Envs: c.Environments, Platforms: c.Platforms, @@ -244,7 +303,6 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) return } - // If the list is empty, we don't need to create node queries if len(targetNodesID) != 0 { if err := h.Queries.CreateNodeQueries(targetNodesID, newQuery.ID); err != nil { log.Err(err).Msgf("error creating node queries for carve %s", newQuery.Name) @@ -252,54 +310,45 @@ func (h *HandlersApi) CarvesRunHandler(w http.ResponseWriter, r *http.Request) { return } } - // Update value for expected if err := h.Queries.SetExpected(newQuery.Name, len(targetNodesID), env.ID); err != nil { apiErrorResponse(w, "error setting expected", http.StatusInternalServerError, err) return } - // Return query name as serialized response - log.Debug().Msgf("Created query %s", newQuery.Name) + log.Debug().Msgf("Created carve %s", newQuery.Name) h.AuditLog.NewCarve(ctx[ctxUser], newQuery.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiQueriesResponse{Name: 
newQuery.Name}) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusCreated, types.ApiQueriesResponse{Name: newQuery.Name}) } -// CarvesActionHandler - POST Handler to delete/expire a carve +// CarvesActionHandler - POST /api/v1/carves/{env}/{action}/{name} func (h *HandlersApi) CarvesActionHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) return } - // Get environment - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } var msgReturn string - // Carve can not be empty nameVar := r.PathValue("name") if nameVar == "" { apiErrorResponse(w, "name can not be empty", http.StatusBadRequest, nil) return } - // Check if carve exists if !h.Queries.Exists(nameVar, env.ID) { apiErrorResponse(w, "carve not found", http.StatusNotFound, nil) return } - // Extract action actionVar := r.PathValue("action") if actionVar == "" { apiErrorResponse(w, "error getting action", http.StatusBadRequest, nil) @@ -324,9 +373,208 @@ func (h *HandlersApi) CarvesActionHandler(w http.ResponseWriter, r *http.Request return } msgReturn = fmt.Sprintf("carve %s completed successfully", nameVar) + default: + apiErrorResponse(w, "invalid action", http.StatusBadRequest, nil) + return } - // Return message as serialized response log.Debug().Msgf("%s", msgReturn) h.AuditLog.CarveAction(ctx[ctxUser], 
actionVar+" carve "+nameVar, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiGenericResponse{Message: msgReturn}) } + +// CarveArchiveHandler - GET /api/v1/carves/{env}/archive/{name} +// +// (The literal `archive` lives in segment 2 — not as a `/{name}/archive` suffix — +// because Go's ServeMux refuses to register patterns that ambiguously overlap with +// `/{env}/queries/{target}` registered on the same prefix.) +// +// Streams (or redirects to) the reassembled carve archive blob. +// +// Resolution rules: +// - The carve query identified by {name} must exist and be type=carve. +// - If exactly one CarvedFile exists for the query, it is served. +// - If multiple exist, an explicit ?session= must select one. +// A missing/ambiguous session selector returns 409 Conflict. +// - If the underlying file is not yet archived, it is archived on demand +// (local or DB carver: written to a temp dir, then served; S3: a presigned +// download URL is returned via 302 redirect). +// +// Content-Disposition is set to attachment with the carve archive filename. +func (h *HandlersApi) CarveArchiveHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + envVar := r.PathValue("env") + name := r.PathValue("name") + if envVar == "" || name == "" { + apiErrorResponse(w, "missing env or name", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) + return + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.CarveLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + + // Confirm the carve query exists and is a carve. 
+ q, err := h.Queries.Get(name, env.ID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "carve not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting carve", http.StatusInternalServerError, err) + return + } + if q.Type != queries.CarveQueryType { + apiErrorResponse(w, "carve not found", http.StatusNotFound, nil) + return + } + + files, err := h.Carves.GetByQuery(name, env.ID) + if err != nil { + apiErrorResponse(w, "error getting carve files", http.StatusInternalServerError, err) + return + } + if len(files) == 0 { + apiErrorResponse(w, "no carved files yet", http.StatusNotFound, nil) + return + } + + requestedSession := strings.TrimSpace(r.URL.Query().Get("session")) + var selected *carves.CarvedFile + switch { + case requestedSession != "": + for i := range files { + if files[i].SessionID == requestedSession { + selected = &files[i] + break + } + } + if selected == nil { + apiErrorResponse(w, "session not found for carve", http.StatusNotFound, nil) + return + } + case len(files) == 1: + selected = &files[0] + default: + // Ambiguous — the caller must pick a session. + sessions := make([]string, 0, len(files)) + for _, f := range files { + sessions = append(sessions, f.SessionID) + } + apiErrorResponse(w, + fmt.Sprintf("carve has %d files; pass ?session= to select one (sessions: %s)", + len(files), strings.Join(sessions, ", ")), + http.StatusConflict, nil) + return + } + + // Materialize the archive if not already done. The path persistence + // strategy differs by carver: + // + // - S3: Archive() multipart-uploads the file to a persistent S3 + // key; we mark the row archived with that key and serve + // a presigned download URL. + // - Local/DB: Archive() reconstructs the file in a workspace dir. The + // API process owns no canonical "carves folder" — the + // legacy admin owns one — so we stage in a per-request + // tmpdir, stream, and do NOT persist the path. 
(Persisting + // would point future requests at a tmpdir we've already + // removed.) The trade-off is re-archiving on each request + // for local/DB carvers, which is correctness over cache. + carve := *selected + + if h.Carves.Carver == config.CarverS3 { + if !carve.Archived { + // Pass empty destPath — Archive() ignores it for the S3 path. + result, aerr := h.Carves.Archive(carve.SessionID, "") + if aerr != nil { + apiErrorResponse(w, "error archiving carve", http.StatusInternalServerError, aerr) + return + } + if result == nil { + apiErrorResponse(w, "empty carve archive", http.StatusInternalServerError, nil) + return + } + if aerr := h.Carves.ArchiveCarve(carve.SessionID, result.File); aerr != nil { + log.Err(aerr).Msgf("error marking carve %s archived", carve.SessionID) + } + carve.Archived = true + carve.ArchivePath = result.File + } + link, lerr := h.Carves.S3.GetDownloadLink(carve) + if lerr != nil { + apiErrorResponse(w, "error generating download link", http.StatusInternalServerError, lerr) + return + } + h.AuditLog.CarveAction(ctx[ctxUser], "download "+name, strings.Split(r.RemoteAddr, ":")[0], env.ID) + http.Redirect(w, r, link, http.StatusFound) + return + } + + // Local / DB carver: stage the archive in a per-request tmpdir and stream + // it back. RemoveAll runs after f.Close (defers are LIFO), so the file is + // readable for the duration of the response. + // + // os.MkdirTemp creates the directory mode 0700, but the file written + // inside by Carves.Archive may end up world-readable depending on + // the platform umask. We chmod it to 0600 explicitly so on a + // multi-tenant container host another tenant on the same node can't + // read the carved bytes during the brief window before RemoveAll. 
+ // + archivePath := carve.ArchivePath + if !carve.Archived { + tmpDir, terr := os.MkdirTemp("", "osctrl-carve-archive-") + if terr != nil { + apiErrorResponse(w, "error preparing archive workspace", http.StatusInternalServerError, terr) + return + } + defer os.RemoveAll(tmpDir) + result, aerr := h.Carves.Archive(carve.SessionID, tmpDir) + if aerr != nil { + apiErrorResponse(w, "error archiving carve", http.StatusInternalServerError, aerr) + return + } + if result == nil { + apiErrorResponse(w, "empty carve archive", http.StatusInternalServerError, nil) + return + } + archivePath = result.File + if err := os.Chmod(archivePath, 0600); err != nil { + log.Err(err).Msgf("failed to chmod 0600 on carve archive %s — proceeding but file may be wider-readable", archivePath) + } + } + + f, ferr := os.Open(archivePath) + if ferr != nil { + apiErrorResponse(w, "error opening archive", http.StatusInternalServerError, ferr) + return + } + defer f.Close() + stat, serr := f.Stat() + if serr != nil { + apiErrorResponse(w, "error stat archive", http.StatusInternalServerError, serr) + return + } + filename := carves.GenerateArchiveName(carve) + // If the on-disk file picked up the zst suffix during archive, preserve it. 
+ if strings.HasSuffix(archivePath, carves.ZstFileExtension) && + !strings.HasSuffix(filename, carves.ZstFileExtension) { + filename += carves.ZstFileExtension + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", strconv.FormatInt(stat.Size(), 10)) + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename)) + w.WriteHeader(http.StatusOK) + if _, err := io.Copy(w, f); err != nil { + log.Err(err).Msgf("error streaming carve archive %s", archivePath) + return + } + h.AuditLog.CarveAction(ctx[ctxUser], "download "+name, strings.Split(r.RemoteAddr, ":")[0], env.ID) +} diff --git a/cmd/api/handlers/environments.go b/cmd/api/handlers/environments.go index 6feb721d..f2057bc3 100644 --- a/cmd/api/handlers/environments.go +++ b/cmd/api/handlers/environments.go @@ -76,7 +76,7 @@ func (h *HandlersApi) EnvironmentHandler(w http.ResponseWriter, r *http.Request) return } // Get environment by UUID - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -186,7 +186,7 @@ func (h *HandlersApi) EnvEnrollHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment by name - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -256,7 +256,7 @@ func (h *HandlersApi) EnvRemoveHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment by name - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -317,7 +317,7 @@ func (h *HandlersApi) EnvEnrollActionsHandler(w http.ResponseWriter, r *http.Req return } // Get environment by name - env, err := 
h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -417,7 +417,7 @@ func (h *HandlersApi) EnvRemoveActionsHandler(w http.ResponseWriter, r *http.Req return } // Get environment by name - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -506,6 +506,7 @@ func (h *HandlersApi) EnvActionsHandler(w http.ResponseWriter, r *http.Request) return } // Validate the optional client-supplied UUID strictly. + // // - utils.CheckUUID delegates to google/uuid Parse, accepting only // canonical UUIDs. EnvUUIDFilter alone is `^[a-z0-9-]+$`, which // would have happily accepted "-", "a", "deadbeef", etc. diff --git a/cmd/api/handlers/environments_crud.go b/cmd/api/handlers/environments_crud.go new file mode 100644 index 00000000..11b5898a --- /dev/null +++ b/cmd/api/handlers/environments_crud.go @@ -0,0 +1,506 @@ +package handlers + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/jmpsec/osctrl/pkg/environments" + "github.com/jmpsec/osctrl/pkg/tags" + "github.com/jmpsec/osctrl/pkg/types" + "github.com/jmpsec/osctrl/pkg/users" + "github.com/jmpsec/osctrl/pkg/utils" + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +// EnvironmentCreateHandler - POST /api/v1/environments +// +// Body: { name, hostname, type? }. Generates a UUID, defaults config / +// schedule / packs / decorators / ATC to "{}", and persists the env. +// Returns 201 with the created TLSEnvironment. Super-admin only. 
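The layered UUID check noted above (a loose character-class filter plus a canonical-form parse) can be sketched in isolation. These regexes are illustrative stand-ins, not the repo's actual EnvUUIDFilter or utils.CheckUUID:

```go
package main

import (
	"fmt"
	"regexp"
)

// looseFilter mimics a permissive character-class gate like `^[a-z0-9-]+$`,
// which accepts non-UUID junk such as "-", "a", or "deadbeef".
// canonicalUUID enforces the 8-4-4-4-12 hex shape that a real UUID parser
// (e.g. google/uuid Parse) would require. Both names are hypothetical.
var (
	looseFilter   = regexp.MustCompile(`^[a-z0-9-]+$`)
	canonicalUUID = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
)

// validEnvUUID layers the strict check on top of the loose one, mirroring
// the handler's approach of validating before touching storage.
func validEnvUUID(s string) bool {
	return looseFilter.MatchString(s) && canonicalUUID.MatchString(s)
}

func main() {
	for _, s := range []string{"deadbeef", "-", "11111111-2222-3333-4444-555555555555"} {
		fmt.Printf("%q -> %v\n", s, validEnvUUID(s))
	}
}
```

Only the full canonical form passes; the strings the loose filter alone would have accepted are rejected.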
+func (h *HandlersApi) EnvironmentCreateHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + var body types.EnvCreateRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err) + return + } + body.Name = strings.TrimSpace(body.Name) + body.Hostname = strings.TrimSpace(body.Hostname) + if !environments.VerifyEnvFilters(body.Name, body.Icon, body.Type, body.Hostname) { + apiErrorResponse(w, "invalid name, hostname, type, or icon", http.StatusBadRequest, nil) + return + } + if h.Envs.Exists(body.Name) { + apiErrorResponse(w, "environment with that name already exists", http.StatusConflict, nil) + return + } + env := h.Envs.Empty(body.Name, body.Hostname) + if body.Type != "" { + env.Type = body.Type + } + if body.Icon != "" { + env.Icon = body.Icon + } + env.Configuration = h.Envs.GenEmptyConfiguration(true) + flags, err := h.Envs.GenerateFlags(env, "", "", h.OsqueryValues) + if err != nil { + apiErrorResponse(w, "error generating flags", http.StatusInternalServerError, err) + return + } + env.Flags = flags + if err := h.Envs.Create(&env); err != nil { + apiErrorResponse(w, "error creating environment", http.StatusInternalServerError, err) + return + } + // Grant the creating user full access to the new environment so it shows up + // in their env list immediately (matches the legacy admin behaviour). 
+ access := h.Users.GenEnvUserAccess([]string{env.UUID}, true, true, true, true) + perms := h.Users.GenPermissions(ctx[ctxUser], h.ServiceName, access) + if err := h.Users.CreatePermissions(perms); err != nil { + log.Err(err).Msgf("env %s created but failed to grant creator permissions", env.Name) + } + // Auto-tag the environment for tag-based targeting. + if !h.Tags.ExistsByEnv(env.Name, env.ID) { + if err := h.Tags.NewTag( + env.Name, + "Tag for environment "+env.Name, + "", + env.Icon, + ctx[ctxUser], + env.ID, + false, + tags.TagTypeEnv, + "", + ); err != nil { + log.Err(err).Msgf("env %s created but failed to create env tag", env.Name) + } + } + h.AuditLog.EnvAction(ctx[ctxUser], "create env "+env.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID) + log.Debug().Msgf("Created environment %s (uuid=%s)", env.Name, env.UUID) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusCreated, env) +} + +// EnvironmentUpdateHandler - PATCH /api/v1/environments/{env} +// +// Updates name / hostname / type / icon / debug_http / accept_enrolls. +// Other env fields go through the per-section endpoints. Super-admin only. 
+func (h *HandlersApi) EnvironmentUpdateHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "missing env", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + var body types.EnvUpdateRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err) + return + } + // Validate every supplied field with the same character-class + // filters the create path uses. Without this gate a super-admin + // (or a compromised super-admin session via a future CSRF gap) + // can PATCH the env name to anything — including shell + // metacharacters and newlines that downstream interpolators + // (genPackageFilename → Content-Disposition, audit-log lines, + // route paths) would happily embed unescaped. 
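To make the header-injection concern in the comment above concrete: a hostile env name that reaches a naive Content-Disposition interpolation unescaped can smuggle quotes and CRLF. This is a hedged sketch; envNamePattern and safeAttachmentHeader are hypothetical names for illustration, not code from this PR:

```go
package main

import (
	"fmt"
	"regexp"
)

// envNamePattern is an assumed stand-in for a strict name filter: letters,
// digits, underscore, dot, dash only. No quotes, no CR/LF, no shell chars.
var envNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_.-]+$`)

// safeAttachmentHeader rejects anything outside the safe character class
// rather than trying to escape it, so downstream interpolators (headers,
// audit-log lines, filenames) never see metacharacters.
func safeAttachmentHeader(name string) (string, bool) {
	if !envNamePattern.MatchString(name) {
		return "", false
	}
	return fmt.Sprintf("attachment; filename=%q", name+".pkg"), true
}

func main() {
	// A name carrying a quote and a CRLF header-split attempt is refused.
	if _, ok := safeAttachmentHeader("prod\"\r\nSet-Cookie: x=1"); !ok {
		fmt.Println("hostile name rejected")
	}
	h, _ := safeAttachmentHeader("prod")
	fmt.Println(h)
}
```

Rejecting at the validation gate, as the PATCH handler does, is simpler and safer than per-sink escaping.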
+ // + patch := map[string]interface{}{} + if body.Name != nil { + n := strings.TrimSpace(*body.Name) + if !environments.EnvNameFilter(n) { + apiErrorResponse(w, "invalid environment name", http.StatusBadRequest, fmt.Errorf("rejected name %q", *body.Name)) + return + } + if n != env.Name { + patch["name"] = n + } + } + if body.Hostname != nil { + host := strings.TrimSpace(*body.Hostname) + if !environments.HostnameFilter(host) { + apiErrorResponse(w, "invalid hostname", http.StatusBadRequest, fmt.Errorf("rejected hostname %q", *body.Hostname)) + return + } + if host != env.Hostname { + patch["hostname"] = host + } + } + if body.Type != nil { + t := strings.TrimSpace(*body.Type) + if !environments.EnvTypeFilter(t) { + apiErrorResponse(w, "invalid environment type", http.StatusBadRequest, fmt.Errorf("rejected type %q", *body.Type)) + return + } + patch["type"] = t + } + if body.Icon != nil { + icon := strings.TrimSpace(*body.Icon) + if !environments.IconFilter(icon) { + apiErrorResponse(w, "invalid icon", http.StatusBadRequest, fmt.Errorf("rejected icon %q", *body.Icon)) + return + } + patch["icon"] = icon + } + if body.DebugHTTP != nil { + patch["debug_http"] = *body.DebugHTTP + } + if body.AcceptEnrolls != nil { + patch["accept_enrolls"] = *body.AcceptEnrolls + } + if len(patch) == 0 { + // Idempotent no-op — return the current env. + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, env) + return + } + if err := h.Envs.DB.Model(&env).Updates(patch).Error; err != nil { + apiErrorResponse(w, "error updating environment", http.StatusInternalServerError, err) + return + } + updated, _ := h.Envs.Get(envVar) + h.AuditLog.EnvAction(ctx[ctxUser], "update env "+env.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID) + log.Debug().Msgf("Updated environment %s", env.Name) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, updated) +} + +// EnvironmentDeleteHandler - DELETE /api/v1/environments/{env} +// +// Removes the environment. 
Super-admin only. Returns 200 with a message. +func (h *HandlersApi) EnvironmentDeleteHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "missing env", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + if err := h.Envs.Delete(envVar); err != nil { + apiErrorResponse(w, "error deleting environment", http.StatusInternalServerError, err) + return + } + h.AuditLog.EnvAction(ctx[ctxUser], "delete env "+env.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID) + log.Debug().Msgf("Deleted environment %s", env.Name) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiGenericResponse{Message: fmt.Sprintf("environment %s deleted", env.Name)}) +} + +// EnvironmentConfigHandler - GET /api/v1/environments/config/{env} +// +// Returns the env's JSON-shaped config sections (options/schedule/packs/ +// decorators/atc/flags) so the SPA's Monaco editor can render each section. 
+func (h *HandlersApi) EnvironmentConfigHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "missing env", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + resp := types.EnvConfigResponse{ + Options: env.Options, + Schedule: env.Schedule, + Packs: env.Packs, + Decorators: env.Decorators, + ATC: env.ATC, + Flags: env.Flags, + } + h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp) +} + +// EnvironmentConfigPatchHandler - PATCH /api/v1/environments/config/{env} +// +// Body: optional options/schedule/packs/decorators/atc/flags string fields. +// Each non-nil field is validated as JSON before persisting; an invalid +// payload is rejected with 400 (no partial writes). 
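The validate-before-write behaviour the doc comment describes reduces to probing every supplied section with json.Unmarshal before any update runs, so one bad section rejects the whole PATCH and nothing is partially persisted. A standalone sketch with assumed names:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// validateSections probes each supplied config section as JSON. The empty
// string is not valid JSON, so it is treated as the empty object, matching
// the handler's behaviour. validateSections is an illustrative name.
func validateSections(sections map[string]string) error {
	for name, val := range sections {
		s := strings.TrimSpace(val)
		if s == "" {
			s = "{}"
		}
		var probe interface{}
		if err := json.Unmarshal([]byte(s), &probe); err != nil {
			return fmt.Errorf("section %q is not valid JSON: %w", name, err)
		}
	}
	return nil
}

func main() {
	ok := map[string]string{"options": `{"verbose": true}`, "schedule": "  "}
	bad := map[string]string{"packs": `{"oops":`}
	fmt.Println(validateSections(ok) == nil, validateSections(bad) != nil)
}
```

Because validation completes before the first UpdateOptions/UpdateSchedule call, a truncated section can never leave the env half-updated.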
+func (h *HandlersApi) EnvironmentConfigPatchHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "missing env", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + var body types.EnvConfigPatchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err) + return + } + // Validate every supplied section is parseable JSON before writing any. + sections := map[string]*string{ + "options": body.Options, + "schedule": body.Schedule, + "packs": body.Packs, + "decorators": body.Decorators, + "atc": body.ATC, + "flags": body.Flags, + } + for name, val := range sections { + if val == nil { + continue + } + // Empty string isn't valid JSON; treat as the empty object. 
+ s := strings.TrimSpace(*val) + if s == "" { + s = "{}" + } + var probe interface{} + if err := json.Unmarshal([]byte(s), &probe); err != nil { + apiErrorResponse(w, fmt.Sprintf("section %q is not valid JSON: %s", name, err.Error()), http.StatusBadRequest, err) + return + } + } + if body.Options != nil { + if err := h.Envs.UpdateOptions(envVar, *body.Options); err != nil { + apiErrorResponse(w, "error updating options", http.StatusInternalServerError, err) + return + } + } + if body.Schedule != nil { + if err := h.Envs.UpdateSchedule(envVar, *body.Schedule); err != nil { + apiErrorResponse(w, "error updating schedule", http.StatusInternalServerError, err) + return + } + } + if body.Packs != nil { + if err := h.Envs.UpdatePacks(envVar, *body.Packs); err != nil { + apiErrorResponse(w, "error updating packs", http.StatusInternalServerError, err) + return + } + } + if body.Decorators != nil { + if err := h.Envs.UpdateDecorators(envVar, *body.Decorators); err != nil { + apiErrorResponse(w, "error updating decorators", http.StatusInternalServerError, err) + return + } + } + if body.ATC != nil { + if err := h.Envs.UpdateATC(envVar, *body.ATC); err != nil { + apiErrorResponse(w, "error updating atc", http.StatusInternalServerError, err) + return + } + } + if body.Flags != nil { + if err := h.Envs.DB.Model(&env).Update("flags", *body.Flags).Error; err != nil { + apiErrorResponse(w, "error updating flags", http.StatusInternalServerError, err) + return + } + } + h.AuditLog.ConfAction(ctx[ctxUser], "config patch on env "+env.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID) + updated, _ := h.Envs.Get(envVar) + resp := types.EnvConfigResponse{ + Options: updated.Options, + Schedule: updated.Schedule, + Packs: updated.Packs, + Decorators: updated.Decorators, + ATC: updated.ATC, + Flags: updated.Flags, + } + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp) +} + +// EnvironmentIntervalsPatchHandler - PATCH /api/v1/environments/intervals/{env} +// +// Body: { 
config_interval?, log_interval?, query_interval? }. Updates the +// three node-pull intervals atomically. Unsupplied fields are kept. +func (h *HandlersApi) EnvironmentIntervalsPatchHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "missing env", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + var body types.EnvIntervalsPatchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err) + return + } + cfg := env.ConfigInterval + lg := env.LogInterval + qr := env.QueryInterval + if body.ConfigInterval != nil { + if *body.ConfigInterval < 1 { + apiErrorResponse(w, "config_interval must be >= 1", http.StatusBadRequest, nil) + return + } + cfg = *body.ConfigInterval + } + if body.LogInterval != nil { + if *body.LogInterval < 1 { + apiErrorResponse(w, "log_interval must be >= 1", http.StatusBadRequest, nil) + return + } + lg = *body.LogInterval + } + if body.QueryInterval != nil { + if *body.QueryInterval < 1 { + apiErrorResponse(w, "query_interval must be >= 1", http.StatusBadRequest, nil) + return + } + qr = *body.QueryInterval + } + if err := h.Envs.UpdateIntervals(env.Name, cfg, lg, qr); err != nil { + apiErrorResponse(w, "error updating intervals", 
http.StatusInternalServerError, err) + return + } + h.AuditLog.ConfAction(ctx[ctxUser], + fmt.Sprintf("intervals patch on env %s: config=%d log=%d query=%d", env.Name, cfg, lg, qr), + strings.Split(r.RemoteAddr, ":")[0], env.ID) + updated, _ := h.Envs.Get(envVar) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, updated) +} + +// EnvironmentExpirationPatchHandler - PATCH /api/v1/environments/expiration/{env} +// +// Convenience wrapper around the existing enrollment lifecycle actions +// (extend / expire / rotate / not-expire), accepting one of those actions +// via JSON body instead of as a path segment. Mirrors the legacy +// EnvEnrollActionsHandler semantics for both enroll and remove paths. +func (h *HandlersApi) EnvironmentExpirationPatchHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "missing env", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + var body types.EnvExpirationPatchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err) + return + } + switch body.Action { + case "extend": + if err := h.Envs.ExtendEnroll(env.UUID); err != nil { + apiErrorResponse(w, "error extending enrollment", http.StatusInternalServerError, err) + return 
+ } + case "expire": + if err := h.Envs.ExpireEnroll(env.UUID); err != nil { + apiErrorResponse(w, "error expiring enrollment", http.StatusInternalServerError, err) + return + } + case "rotate": + if err := h.Envs.RotateEnroll(env.UUID); err != nil { + apiErrorResponse(w, "error rotating enrollment", http.StatusInternalServerError, err) + return + } + case "not-expire": + if err := h.Envs.NotExpireEnroll(env.UUID); err != nil { + apiErrorResponse(w, "error setting no expiration", http.StatusInternalServerError, err) + return + } + default: + apiErrorResponse(w, "action must be one of: extend, expire, rotate, not-expire", http.StatusBadRequest, nil) + return + } + h.AuditLog.EnvAction(ctx[ctxUser], body.Action+" enrollment for env "+env.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID) + updated, _ := h.Envs.Get(envVar) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, updated) +} + +// Go rejects unused imports at compile time (an error, not a warning). The +// environments package is referenced above via its validation filters, but +// this no-op reference keeps the import valid if those calls are ever +// refactored away. +var _ = environments.EnrollShell diff --git a/cmd/api/handlers/environments_test.go b/cmd/api/handlers/environments_test.go index bbe332cf..6ed775c7 100644 --- a/cmd/api/handlers/environments_test.go +++ b/cmd/api/handlers/environments_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/jmpsec/osctrl/pkg/environments" - "gorm.io/gorm" ) // TestProjectEnvironmentViewStripsSecrets is the load-bearing regression test @@ -19,16 +18,14 @@ import ( // known-sensitive substring is absent from the serialized JSON.

func TestProjectEnvironmentViewStripsSecrets(t *testing.T) { src := environments.TLSEnvironment{ - Model: gorm.Model{ - ID: 1, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - }, - UUID: "11111111-2222-3333-4444-555555555555", - Name: "prod", - Hostname: "osctrl.example.com", - Type: "dev", - Icon: "rocket", + ID: 1, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + UUID: "11111111-2222-3333-4444-555555555555", + Name: "prod", + Hostname: "osctrl.example.com", + Type: "dev", + Icon: "rocket", // The fields below must NOT appear in the projection. Secret: "SECRET-MARKER-enroll", EnrollSecretPath: "SECRET-MARKER-enroll-path", diff --git a/cmd/api/handlers/handlers.go b/cmd/api/handlers/handlers.go index 4b6b5b85..dea325ef 100644 --- a/cmd/api/handlers/handlers.go +++ b/cmd/api/handlers/handlers.go @@ -11,6 +11,7 @@ import ( "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/tags" + "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/users" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -36,6 +37,7 @@ type HandlersApi struct { ApiConfig *config.APIConfiguration DebugHTTP *zerolog.Logger DebugHTTPConfig *config.YAMLConfigurationDebug + OsqueryTables []types.OsqueryTable OsqueryValues config.YAMLConfigurationOsquery } @@ -112,12 +114,19 @@ func WithAuditLog(auditLog *auditlog.AuditLogManager) HandlersOption { h.AuditLog = auditLog } } + func WithOsqueryValues(values config.YAMLConfigurationOsquery) HandlersOption { return func(h *HandlersApi) { h.OsqueryValues = values } } +func WithOsqueryTables(tables []types.OsqueryTable) HandlersOption { + return func(h *HandlersApi) { + h.OsqueryTables = tables + } +} + func WithDebugHTTP(cfg *config.YAMLConfigurationDebug) HandlersOption { return func(h *HandlersApi) { h.DebugHTTPConfig = cfg diff --git a/cmd/api/handlers/login_envs.go b/cmd/api/handlers/login_envs.go new file mode 100644 index 00000000..b1b5c729 --- /dev/null +++ 
b/cmd/api/handlers/login_envs.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "net/http" + + "github.com/jmpsec/osctrl/pkg/types" + "github.com/jmpsec/osctrl/pkg/utils" +) + +// LoginEnvironmentsHandler - GET /api/v1/login/environments +// +// Pre-auth endpoint that returns the list of environments the user may attempt +// to log into. Surface is intentionally minimal: only the env UUID and name. +// No enroll secrets, no certificates, no settings, no hostnames — those all +// stay behind auth on /api/v1/environments and its CRUD siblings. +// +// Rationale: forcing the user to type the env name on the login screen is bad +// UX (you don't know it until you've logged in once, and single-env installs +// only ever have one option). The legacy admin shows env names pre-auth in its +// login form, so we're not changing the security posture — just exposing the +// same identifiers that the URL space already commits to using post-auth. +// +// Like POST /login/{env}, this lives behind the per-IP rate limit registered +// in main.go so the endpoint can't be turned into an env-enumeration oracle +// for brute-force prep beyond the limit. +func (h *HandlersApi) LoginEnvironmentsHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + envs, err := h.Envs.All() + if err != nil { + apiErrorResponse(w, "error listing environments", http.StatusInternalServerError, err) + return + } + // Project to (uuid, name) only. Constructing the response explicitly + // guards against future fields being added to TLSEnvironment that + // shouldn't be exposed pre-auth — if someone adds e.g. a `Secret` field + // to that struct later, this handler still ships only the two fields + // listed here. 
+ out := make([]types.LoginEnvironment, 0, len(envs)) + for _, e := range envs { + out = append(out, types.LoginEnvironment{ + UUID: e.UUID, + Name: e.Name, + }) + } + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, out) +} diff --git a/cmd/api/handlers/logs.go b/cmd/api/handlers/logs.go new file mode 100644 index 00000000..8f71250b --- /dev/null +++ b/cmd/api/handlers/logs.go @@ -0,0 +1,124 @@ +package handlers + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/jmpsec/osctrl/pkg/logging" + "github.com/jmpsec/osctrl/pkg/types" + "github.com/jmpsec/osctrl/pkg/users" + "github.com/jmpsec/osctrl/pkg/utils" + "github.com/rs/zerolog/log" +) + +// NodeLogsResponse is the SPA-canonical response for GET /api/v1/logs/{type}/{env}/{uuid}. +type NodeLogsResponse struct { + Items []map[string]any `json:"items"` + Type string `json:"type"` + UUID string `json:"uuid"` + Env string `json:"env"` + Since string `json:"since,omitempty"` + Limit int `json:"limit"` +} + +// NodeLogsHandler returns recent log entries for a node. 
+// +// Path: /api/v1/logs/{type}/{env}/{uuid} +// +// type: "status" | "result" +// env: UUID or name +// uuid: node UUID +// +// Query params: +// +// since: RFC3339 timestamp; entries strictly after this point only +// limit: 1..1000 (default 100) +func (h *HandlersApi) NodeLogsHandler(w http.ResponseWriter, r *http.Request) { + // Debug HTTP if enabled + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + logType := r.PathValue("type") + switch logType { + case types.StatusLog, types.ResultLog: + default: + apiErrorResponse(w, "invalid log type (status|result)", http.StatusBadRequest, nil) + return + } + envVar := r.PathValue("env") + nodeUUID := r.PathValue("uuid") + + env, err := h.Envs.Get(envVar) + if err != nil { + envByName, err2 := h.Envs.GetByName(envVar) + if err2 != nil { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + env = envByName + } + + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + + // Verify the node exists in this env — prevents probing for arbitrary UUIDs + // across tenants (resolves cross-tenant log read vector). 
+ node, err := h.Nodes.GetByUUID(nodeUUID) + if err != nil { + apiErrorResponse(w, "node not found", http.StatusNotFound, err) + return + } + if node.Environment == "" || !strings.EqualFold(node.Environment, env.Name) { + apiErrorResponse(w, "node not in environment", http.StatusForbidden, nil) + return + } + + q := r.URL.Query() + limit, _ := strconv.Atoi(q.Get("limit")) + if limit <= 0 { + limit = 100 + } + if limit > 1000 { + limit = 1000 + } + var since time.Time + if s := q.Get("since"); s != "" { + t, err := time.Parse(time.RFC3339, s) + if err != nil { + apiErrorResponse(w, "invalid since (expected RFC3339)", http.StatusBadRequest, err) + return + } + since = t + } + // Optional free-text filter. Substring match against the log row's + // human-readable columns (line / message / filename for status logs; + // name / action / columns JSON for result logs). Server-side so + // operators can search the full history, not just the visible page. + search := strings.TrimSpace(q.Get("q")) + + // Use the node's canonical UUID (already upper-cased in the DB) from the + // verified node record, not the raw URL parameter. 
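The query-parameter handling above (default and cap on limit, strict RFC3339 for since) is easy to isolate. A sketch with assumed helper names, mirroring the handler's clamping and parse rules:

```go
package main

import (
	"fmt"
	"strconv"
	"time"
)

// parseLimit applies the handler's rules: non-numeric, missing, or
// non-positive input falls back to 100; anything above 1000 is capped.
func parseLimit(raw string) int {
	limit, _ := strconv.Atoi(raw)
	if limit <= 0 {
		return 100
	}
	if limit > 1000 {
		return 1000
	}
	return limit
}

// parseSince accepts the empty string (no lower bound) or a strict RFC3339
// timestamp; anything else is an error the handler maps to a 400.
func parseSince(raw string) (time.Time, error) {
	if raw == "" {
		return time.Time{}, nil
	}
	return time.Parse(time.RFC3339, raw)
}

func main() {
	fmt.Println(parseLimit(""), parseLimit("250"), parseLimit("99999"))
	_, err := parseSince("not-a-time")
	fmt.Println(err != nil)
}
```

Clamping server-side means a hostile or buggy client cannot turn the logs endpoint into an unbounded table scan.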
+ items, err := logging.GetNodeLogs(h.DB, logType, env.Name, node.UUID, since, limit, search) + if err != nil { + apiErrorResponse(w, "failed to query logs", http.StatusInternalServerError, err) + return + } + if items == nil { + items = []map[string]any{} + } + + log.Debug().Msgf("Returned %d %s log entries for node %s", len(items), logType, node.UUID) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, NodeLogsResponse{ + Items: items, + Type: logType, + UUID: node.UUID, + Env: env.UUID, + Since: q.Get("since"), + Limit: limit, + }) +} diff --git a/cmd/api/handlers/nodes.go b/cmd/api/handlers/nodes.go index e1299ab7..b374d172 100644 --- a/cmd/api/handlers/nodes.go +++ b/cmd/api/handlers/nodes.go @@ -4,9 +4,11 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "strings" "github.com/jmpsec/osctrl/pkg/nodes" + "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/users" "github.com/jmpsec/osctrl/pkg/utils" @@ -26,7 +28,7 @@ func (h *HandlersApi) NodeHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil) return @@ -43,9 +45,8 @@ func (h *HandlersApi) NodeHandler(w http.ResponseWriter, r *http.Request) { apiErrorResponse(w, "error getting node", http.StatusBadRequest, nil) return } - // Get node by identifier - // FIXME keep a cache of nodes by node identifier - node, err := h.Nodes.GetByIdentifier(nodeVar) + // Get node by identifier, scoped to this environment + node, err := h.Nodes.GetByIdentifierEnv(nodeVar, env.ID) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "node not found", http.StatusNotFound, err) @@ -56,8 +57,11 @@ func (h *HandlersApi) NodeHandler(w http.ResponseWriter, r *http.Request) { } log.Debug().Msgf("Returned node %s", nodeVar) 
 	h.AuditLog.NodeAction(ctx[ctxUser], "viewed node "+nodeVar, strings.Split(r.RemoteAddr, ":")[0], env.ID)
-	// Serialize and serve JSON
-	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, node)
+	// Project to the SPA-facing view that surfaces parsed-and-sanitized
+	// enrichment fields (CPU cores, BIOS, hardware vendor/model) from
+	// the otherwise-hidden RawEnrollment blob. The enroll_secret inside that
+	// blob is intentionally NOT in the projection — see pkg/types/node_view.go.
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ProjectNode(node))
 }
 
 // ActiveNodesHandler - GET Handler for active JSON nodes
@@ -73,7 +77,7 @@ func (h *HandlersApi) ActiveNodesHandler(w http.ResponseWriter, r *http.Request)
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -84,20 +88,21 @@ func (h *HandlersApi) ActiveNodesHandler(w http.ResponseWriter, r *http.Request)
 		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
 		return
 	}
-	// Get nodes
-	nodes, err := h.Nodes.Gets(nodes.ActiveNodes, 24)
+	// Get nodes — scoped to this environment (resolves audit finding U-DB-2)
+	hours := h.Settings.InactiveHours(settings.NoEnvironmentID)
+	nodeList, err := h.Nodes.GetByEnv(env.Name, nodes.ActiveNodes, hours)
 	if err != nil {
 		apiErrorResponse(w, "error getting nodes", http.StatusInternalServerError, err)
 		return
 	}
-	if len(nodes) == 0 {
+	if len(nodeList) == 0 {
 		apiErrorResponse(w, "no nodes", http.StatusNotFound, nil)
 		return
 	}
 	// Serialize and serve JSON
 	log.Debug().Msg("Returned active nodes")
 	h.AuditLog.NodeAction(ctx[ctxUser], "viewed active nodes", strings.Split(r.RemoteAddr, ":")[0], env.ID)
-	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, nodes)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, nodeList)
 }
 
 // InactiveNodesHandler - GET Handler for inactive JSON nodes
@@ -113,7 +118,7 @@ func (h *HandlersApi) InactiveNodesHandler(w http.ResponseWriter, r *http.Reques
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -124,20 +129,21 @@ func (h *HandlersApi) InactiveNodesHandler(w http.ResponseWriter, r *http.Reques
 		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
 		return
 	}
-	// Get nodes
-	nodes, err := h.Nodes.Gets(nodes.InactiveNodes, 24)
+	// Get nodes — scoped to this environment (resolves audit finding U-DB-2)
+	hours := h.Settings.InactiveHours(settings.NoEnvironmentID)
+	nodeList, err := h.Nodes.GetByEnv(env.Name, nodes.InactiveNodes, hours)
 	if err != nil {
 		apiErrorResponse(w, "error getting nodes", http.StatusInternalServerError, err)
 		return
 	}
-	if len(nodes) == 0 {
+	if len(nodeList) == 0 {
 		apiErrorResponse(w, "no nodes", http.StatusNotFound, nil)
 		return
 	}
 	// Serialize and serve JSON
 	log.Debug().Msg("Returned inactive nodes")
 	h.AuditLog.NodeAction(ctx[ctxUser], "viewed inactive nodes", strings.Split(r.RemoteAddr, ":")[0], env.ID)
-	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, nodes)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, nodeList)
 }
 
 // AllNodesHandler - GET Handler for all JSON nodes
@@ -153,7 +159,7 @@ func (h *HandlersApi) AllNodesHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusBadRequest, nil)
 		return
@@ -164,20 +170,20 @@ func (h *HandlersApi) AllNodesHandler(w http.ResponseWriter, r *http.Request) {
 		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
 		return
 	}
-	// Get nodes
-	nodes, err := h.Nodes.Gets(nodes.AllNodes, 0)
+	// Get nodes — scoped to this environment (resolves audit finding U-DB-2)
+	nodeList, err := h.Nodes.GetByEnv(env.Name, nodes.AllNodes, 0)
 	if err != nil {
 		apiErrorResponse(w, "error getting nodes", http.StatusInternalServerError, err)
 		return
 	}
-	if len(nodes) == 0 {
+	if len(nodeList) == 0 {
 		apiErrorResponse(w, "no nodes", http.StatusNotFound, nil)
 		return
 	}
 	// Serialize and serve JSON
 	log.Debug().Msg("Returned all nodes")
 	h.AuditLog.NodeAction(ctx[ctxUser], "viewed all nodes", strings.Split(r.RemoteAddr, ":")[0], env.ID)
-	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, nodes)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, nodeList)
 }
 
 // DeleteNodeHandler - POST Handler to delete single node
@@ -193,7 +199,7 @@ func (h *HandlersApi) DeleteNodeHandler(w http.ResponseWriter, r *http.Request)
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -237,7 +243,7 @@ func (h *HandlersApi) TagNodeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -251,7 +257,11 @@ func (h *HandlersApi) TagNodeHandler(w http.ResponseWriter, r *http.Request) {
 	var t types.ApiNodeTagRequest
 	// Parse request JSON body
 	if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
-		apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err)
+		apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err)
+		return
+	}
+	if t.UUID == "" || t.Tag == "" {
+		apiErrorResponse(w, "uuid and tag are required", http.StatusBadRequest, nil)
 		return
 	}
 	// Get node by UUID
@@ -310,3 +320,122 @@ func (h *HandlersApi) LookupNodeHandler(w http.ResponseWriter, r *http.Request)
 	// Serialize and serve JSON
 	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, n)
 }
+
+// NodesPagedHandler returns paginated, sorted, searchable nodes for an env.
+// This is the canonical endpoint consumed by the React admin SPA.
+//
+// Query params:
+//
+//	status: "all" | "active" | "inactive" (default "all")
+//	q: free-text search (case-insensitive partial match on uuid,
+//	   hostname, localname, ip, username, osquery_user, platform, version)
+//	sort: one of nodes.SortableColumns keys (default "lastseen")
+//	dir: "asc" | "desc" (default "desc" for lastseen, "asc" otherwise)
+//	page: 1-indexed page number (default 1)
+//	page_size: 1..500 (default 50)
+func (h *HandlersApi) NodesPagedHandler(w http.ResponseWriter, r *http.Request) {
+	// Debug HTTP if enabled
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	// env from URL path
+	envVar := r.PathValue("env")
+	env, err := h.Envs.Get(envVar)
+	if err != nil {
+		// try by name for callers that pass an env name (legacy compat)
+		envByName, err2 := h.Envs.GetByName(envVar)
+		if err2 != nil {
+			apiErrorResponse(w, "environment not found", http.StatusNotFound, err)
+			return
+		}
+		env = envByName
+	}
+
+	// auth context — user
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.UserLevel, env.UUID) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+
+	// params
+	q := r.URL.Query()
+	status := q.Get("status")
+	if status == "" {
+		status = "all"
+	}
+	switch status {
+	case "all", "active", "inactive":
+	default:
+		apiErrorResponse(w, "invalid status (all|active|inactive)", http.StatusBadRequest, nil)
+		return
+	}
+	search := q.Get("q")
+	dirParam := strings.ToLower(q.Get("dir"))
+	sortCol := q.Get("sort")
+	var desc bool
+	switch dirParam {
+	case "asc":
+		desc = false
+	case "desc":
+		desc = true
+	default:
+		// No explicit direction: default to desc for time-based columns,
+		// asc for everything else. Matches OpenAPI default of "desc" for
+		// the most common SPA sort (lastseen).
+		switch sortCol {
+		case "", "lastseen", "firstseen":
+			desc = true
+		default:
+			desc = false
+		}
+	}
+	page, _ := strconv.Atoi(q.Get("page"))
+	pageSize, _ := strconv.Atoi(q.Get("page_size"))
+
+	// Platform bucket filter — empty string disables. Validated inside
+	// applyPlatformBucket: unknown buckets become no-ops. We do still allow
+	// the explicit value "other" so the SPA can offer an "Other" chip for
+	// platforms that don't fit linux/darwin/windows.
+	platformBucket := strings.ToLower(strings.TrimSpace(q.Get("platform")))
+	switch platformBucket {
+	case "", "linux", "darwin", "windows", "other":
+		// allowed
+	default:
+		apiErrorResponse(w, "invalid platform (linux|darwin|windows|other)", http.StatusBadRequest, nil)
+		return
+	}
+
+	hours := h.Settings.InactiveHours(settings.NoEnvironmentID)
+	pageData, err := h.Nodes.GetByEnvPaged(env.Name, status, hours, search, page, pageSize, sortCol, desc, platformBucket)
+	if err != nil {
+		apiErrorResponse(w, "failed to query nodes", http.StatusInternalServerError, err)
+		return
+	}
+
+	// Normalize page/pageSize back so the client sees what was actually applied.
+	if pageSize <= 0 {
+		pageSize = 50
+	} else if pageSize > 500 {
+		pageSize = 500
+	}
+	if page <= 0 {
+		page = 1
+	}
+	totalPages := int((pageData.TotalItems + int64(pageSize) - 1) / int64(pageSize))
+	if totalPages == 0 {
+		totalPages = 1
+	}
+
+	log.Debug().Msgf("Returned paged nodes for env %s page %d", env.Name, page)
+	h.AuditLog.NodeAction(ctx[ctxUser], "viewed paged nodes", strings.Split(r.RemoteAddr, ":")[0], env.ID)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.NodesPagedResponse{
+		// ProjectNodes adds the parsed `system_info` enrichment block per row.
+		// The enroll_secret inside RawEnrollment is intentionally excluded.
+		Items:      types.ProjectNodes(pageData.Items),
+		Page:       page,
+		PageSize:   pageSize,
+		TotalItems: pageData.TotalItems,
+		TotalPages: totalPages,
+	})
+}
diff --git a/cmd/api/handlers/queries.go b/cmd/api/handlers/queries.go
index 36f341a5..bf52bda2 100644
--- a/cmd/api/handlers/queries.go
+++ b/cmd/api/handlers/queries.go
@@ -1,13 +1,17 @@
 package handlers
 
 import (
+	"encoding/csv"
 	"encoding/json"
 	"fmt"
 	"net/http"
+	"sort"
+	"strconv"
 	"strings"
 	"time"
 
 	"github.com/jmpsec/osctrl/pkg/handlers"
+	"github.com/jmpsec/osctrl/pkg/logging"
 	"github.com/jmpsec/osctrl/pkg/queries"
 	"github.com/jmpsec/osctrl/pkg/settings"
 	"github.com/jmpsec/osctrl/pkg/types"
@@ -16,11 +20,13 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
+// QueryTargets enumerates the target filters accepted by QueryListHandler.
+// TargetHiddenActive is intentionally excluded: no UI tab references it and
+// GetByEnvTargetPaged has no branch for it (mirrors Gets() which returns nothing).
 var QueryTargets = map[string]bool{
 	queries.TargetAll:       true,
 	queries.TargetAllFull:   true,
 	queries.TargetActive:    true,
-	queries.TargetHiddenActive: true,
 	queries.TargetCompleted: true,
 	queries.TargetExpired:   true,
 	queries.TargetSaved:     true,
@@ -48,7 +54,7 @@ func (h *HandlersApi) QueryShowHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -88,7 +94,7 @@ func (h *HandlersApi) QueriesRunHandler(w http.ResponseWriter, r *http.Request)
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -196,7 +202,7 @@ func (h *HandlersApi) QueriesActionHandler(w http.ResponseWriter, r *http.Reques
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -264,7 +270,7 @@ func (h *HandlersApi) AllQueriesShowHandler(w http.ResponseWriter, r *http.Reque
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -291,7 +297,9 @@ func (h *HandlersApi) AllQueriesShowHandler(w http.ResponseWriter, r *http.Reque
 	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, queries)
 }
 
-// QueryListHandler - GET Handler to return queries in JSON by target and environment
+// QueryListHandler - GET Handler to return queries in JSON by target and environment (paginated)
+//
+// Query params: page, page_size, q (free-text search), sort (column key), dir (asc|desc)
 func (h *HandlersApi) QueryListHandler(w http.ResponseWriter, r *http.Request) {
 	// Debug HTTP if enabled
 	if h.DebugHTTPConfig.EnableHTTP {
@@ -304,7 +312,7 @@ func (h *HandlersApi) QueryListHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -326,23 +334,62 @@ func (h *HandlersApi) QueryListHandler(w http.ResponseWriter, r *http.Request) {
 		apiErrorResponse(w, "invalid target", http.StatusBadRequest, nil)
 		return
 	}
-	// Get queries
-	queries, err := h.Queries.GetQueries(targetVar, env.ID)
+	// Parse pagination / search / sort params
+	q := r.URL.Query()
+	page, _ := strconv.Atoi(q.Get("page"))
+	pageSize, _ := strconv.Atoi(q.Get("page_size"))
+	search := q.Get("q")
+	sortCol := q.Get("sort")
+	desc := strings.ToLower(q.Get("dir")) != "asc"
+
+	// Clamp pagination once at the handler so the response echoes effective
+	// values; the package function still clamps defensively.
+	if pageSize <= 0 {
+		pageSize = 50
+	}
+	if pageSize > 500 {
+		pageSize = 500
+	}
+	if page <= 0 {
+		page = 1
+	}
+
+	result, err := h.Queries.GetByEnvTargetPaged(env.ID, targetVar, queries.StandardQueryType, search, page, pageSize, sortCol, desc)
 	if err != nil {
 		apiErrorResponse(w, "error getting queries", http.StatusInternalServerError, err)
 		return
 	}
-	if len(queries) == 0 {
-		apiErrorResponse(w, "no queries", http.StatusNotFound, nil)
-		return
+
+	// Empty result is a valid state — return HTTP 200 with empty items.
+	items := result.Items
+	if items == nil {
+		items = []queries.DistributedQuery{}
+	}
+	var totalPages int
+	if result.TotalItems > 0 {
+		totalPages = int((result.TotalItems + int64(pageSize) - 1) / int64(pageSize))
 	}
+
+	resp := types.QueriesPagedResponse{
+		Items:      items,
+		Page:       page,
+		PageSize:   pageSize,
+		TotalItems: result.TotalItems,
+		TotalPages: totalPages,
+	}
+
 	// Serialize and serve JSON
-	log.Debug().Msgf("Returned %d queries", len(queries))
+	log.Debug().Msgf("Returned %d queries (page %d of %d)", len(items), page, totalPages)
 	h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID)
-	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, queries)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp)
 }
 
-// QueryResultsHandler - GET Handler to return a single query results in JSON
+// QueryResultsHandler - GET Handler to return paginated query results in JSON
+//
+// Path: /api/v1/queries/{env}/results/{name}
+// Params: page, page_size, since (RFC3339 timestamp; unparseable → ignored)
+//
+// Empty results are a valid state and return HTTP 200 with items: [].
 func (h *HandlersApi) QueryResultsHandler(w http.ResponseWriter, r *http.Request) {
 	// Debug HTTP if enabled
 	if h.DebugHTTPConfig.EnableHTTP {
@@ -357,11 +404,11 @@ func (h *HandlersApi) QueryResultsHandler(w http.ResponseWriter, r *http.Request
 	// Extract environment
 	envVar := r.PathValue("env")
 	if envVar == "" {
-		apiErrorResponse(w, "error with environment", http.StatusInternalServerError, nil)
+		apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil)
 		return
 	}
 	// Get environment
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
 		return
@@ -380,20 +427,175 @@ func (h *HandlersApi) QueryResultsHandler(w http.ResponseWriter, r *http.Request
 		apiErrorResponse(w, "query not found", http.StatusNotFound, nil)
 		return
 	}
-	// Get query by name
-	// TODO this is a temporary solution, we need to refactor this and take into consideration the
-	// logger for TLS and whether if the results are stored in the DB or a different DB
-	queryLogs, err := postgresQueryLogs(h.DB, name)
+
+	// Parse pagination + since cursor
+	q := r.URL.Query()
+	page, _ := strconv.Atoi(q.Get("page"))
+	pageSize, _ := strconv.Atoi(q.Get("page_size"))
+	if pageSize <= 0 {
+		pageSize = 100
+	}
+	if pageSize > 1000 {
+		pageSize = 1000
+	}
+	if page <= 0 {
+		page = 1
+	}
+	var since time.Time
+	var sinceEcho string
+	if s := strings.TrimSpace(q.Get("since")); s != "" {
+		if t, perr := time.Parse(time.RFC3339, s); perr == nil {
+			since = t
+			sinceEcho = s
+		}
+	}
+
+	items, total, err := logging.GetQueryResults(h.DB, name, since, page, pageSize)
 	if err != nil {
-		if err.Error() == "record not found" {
-			apiErrorResponse(w, "query not found", http.StatusNotFound, err)
-		} else {
-			apiErrorResponse(w, "error getting query", http.StatusInternalServerError, err)
+		apiErrorResponse(w, "error getting query results", http.StatusInternalServerError, err)
+		return
+	}
+	if items == nil {
+		items = []map[string]any{}
+	}
+	var totalPages int
+	if total > 0 {
+		totalPages = int((total + int64(pageSize) - 1) / int64(pageSize))
+	}
+	resp := types.QueryResultsResponse{
+		Items:      items,
+		Page:       page,
+		PageSize:   pageSize,
+		TotalItems: total,
+		TotalPages: totalPages,
+		Since:      sinceEcho,
+	}
+	log.Debug().Msgf("Returned query results for %s (page %d of %d, %d rows)", name, page, totalPages, len(items))
+	h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp)
+}
+
+// QueryResultsCSVHandler - GET Handler to stream query results as CSV
+//
+// Path: /api/v1/queries/{env}/results/csv/{name}
+//
+// (The `.csv` lives as a literal path segment before `{name}` because Go's
+// ServeMux grammar requires wildcards to end at `/` or end-of-pattern, so
+// `{name}.csv` is a parse error at registration time.)
+func (h *HandlersApi) QueryResultsCSVHandler(w http.ResponseWriter, r *http.Request) {
+	// Debug HTTP if enabled
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	name := r.PathValue("name")
+	if name == "" {
+		apiErrorResponse(w, "error getting name", http.StatusBadRequest, nil)
+		return
+	}
+	// Extract environment
+	envVar := r.PathValue("env")
+	if envVar == "" {
+		apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil)
+		return
+	}
+	// Get environment
+	env, err := h.Envs.Get(envVar)
+	if err != nil {
+		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
+		return
+	}
+	// Get context data and check access
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, env.UUID) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+	// Verify the named query belongs to THIS env. See the matching gate
+	// in QueryResultsHandler for the rationale.
+	if !h.Queries.Exists(name, env.ID) {
+		apiErrorResponse(w, "query not found", http.StatusNotFound, nil)
+		return
+	}
+	// Pass 1 (streaming): walk every row, collect the union of column names.
+	// We only retain column names here — never the row data — to keep memory at O(columns).
+	colSet := make(map[string]struct{})
+	if err := logging.StreamQueryResults(h.DB, name, func(row logging.OsqueryQueryData) error {
+		var cols map[string]string
+		if err := json.Unmarshal([]byte(row.Data), &cols); err != nil {
+			cols = map[string]string{"data": row.Data}
 		}
+		for k := range cols {
+			colSet[k] = struct{}{}
+		}
+		return nil
+	}); err != nil {
+		apiErrorResponse(w, "error getting query results", http.StatusInternalServerError, err)
 		return
 	}
-	// Serialize and serve JSON
-	log.Debug().Msgf("Returned query results for %s", name)
+	headers := make([]string, 0, len(colSet)+1)
+	headers = append(headers, "uuid")
+	sortedCols := make([]string, 0, len(colSet))
+	for k := range colSet {
+		sortedCols = append(sortedCols, k)
+	}
+	sort.Strings(sortedCols)
+	headers = append(headers, sortedCols...)
+
+	// Set response headers BEFORE writing any body.
+	w.Header().Set("Content-Type", "text/csv; charset=utf-8")
+	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name+".csv"))
+
+	cw := csv.NewWriter(w)
+	flusher, _ := w.(http.Flusher)
+	if err := cw.Write(headers); err != nil {
+		log.Err(err).Msgf("error writing CSV header for %s", name)
+		return
+	}
+	cw.Flush()
+	if flusher != nil {
+		flusher.Flush()
+	}
+
+	// Pass 2 (streaming): write data rows, flushing after each so bytes reach the client incrementally.
+	rowCount := 0
+	if err := logging.StreamQueryResults(h.DB, name, func(row logging.OsqueryQueryData) error {
+		var cols map[string]string
+		if err := json.Unmarshal([]byte(row.Data), &cols); err != nil {
+			cols = map[string]string{"data": row.Data}
+		}
+		record := make([]string, len(headers))
+		record[0] = row.UUID
+		for i, col := range sortedCols {
+			record[i+1] = cols[col]
+		}
+		if werr := cw.Write(record); werr != nil {
+			return werr
+		}
+		cw.Flush()
+		if flusher != nil {
+			flusher.Flush()
+		}
+		rowCount++
+		return nil
+	}); err != nil {
+		// Headers already sent; we can only log and stop.
+		log.Err(err).Msgf("error streaming CSV rows for %s", name)
+		return
+	}
+	log.Debug().Msgf("Exported CSV for query %s (%d rows)", name, rowCount)
 	h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID)
-	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, queryLogs)
+}
+
+// OsqueryTablesHandler - GET Handler to return the osquery schema tables
+//
+// Path: /api/v1/osquery/tables
+// The schema is global (not env-scoped). Requires any authenticated user.
+// Responses are cacheable for one hour since the schema rarely changes.
+func (h *HandlersApi) OsqueryTablesHandler(w http.ResponseWriter, r *http.Request) {
+	// Debug HTTP if enabled
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	w.Header().Set("Cache-Control", "private, max-age=3600")
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, h.OsqueryTables)
+}
diff --git a/cmd/api/handlers/samples.go b/cmd/api/handlers/samples.go
new file mode 100644
index 00000000..78a3c9fd
--- /dev/null
+++ b/cmd/api/handlers/samples.go
@@ -0,0 +1,38 @@
+package handlers
+
+import (
+	"net/http"
+
+	"github.com/jmpsec/osctrl/pkg/carves"
+	"github.com/jmpsec/osctrl/pkg/queries"
+	"github.com/jmpsec/osctrl/pkg/utils"
+)
+
+// QuerySamplesHandler - GET /api/v1/queries/samples
+//
+// Returns the static starter library of osquery SQL templates so the SPA's
+// queries/new form can populate its QuickTemplates row. Intentionally
+// unauthenticated: the samples are read-only data shipped with the binary;
+// they aren't tenant- or env-scoped, and exposing them pre-auth lets the
+// login screen lazy-load them without circular dependencies.
+//
+// Shares the per-IP loginRateLimit registered in main.go so this endpoint
+// can't be turned into a low-effort scanning probe.
+func (h *HandlersApi) QuerySamplesHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, queries.QuerySamples)
+}
+
+// CarveSamplesHandler - GET /api/v1/carves/samples
+//
+// Returns the static starter library of common carve-target file paths
+// (e.g., /etc/passwd, C:\Windows\System32\config\SAM). Same auth posture as
+// QuerySamplesHandler: pre-auth, rate-limited.
+func (h *HandlersApi) CarveSamplesHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, carves.CarveSamples)
+}
diff --git a/cmd/api/handlers/saved_queries.go b/cmd/api/handlers/saved_queries.go
new file mode 100644
index 00000000..bdd6c72a
--- /dev/null
+++ b/cmd/api/handlers/saved_queries.go
@@ -0,0 +1,257 @@
+package handlers
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"strconv"
+	"strings"
+
+	"github.com/jmpsec/osctrl/pkg/queries"
+	"github.com/jmpsec/osctrl/pkg/types"
+	"github.com/jmpsec/osctrl/pkg/users"
+	"github.com/jmpsec/osctrl/pkg/utils"
+	"github.com/rs/zerolog/log"
+	"gorm.io/gorm"
+)
+
+// savedQueryView projects a storage row into the SPA-canonical envelope.
+// Timestamps stay as time.Time so JSON-encoded output is RFC3339 — matches
+// the OpenAPI date-time format and the SPA's formatRelative ISO parser.
+func savedQueryView(s queries.SavedQuery) types.SavedQueryView {
+	return types.SavedQueryView{
+		ID:            s.ID,
+		CreatedAt:     s.CreatedAt,
+		UpdatedAt:     s.UpdatedAt,
+		Name:          s.Name,
+		Creator:       s.Creator,
+		Query:         s.Query,
+		EnvironmentID: s.EnvironmentID,
+		ExtraData:     s.ExtraData,
+	}
+}
+
+// SavedQueriesListHandler - GET /api/v1/saved-queries/{env}
+//
+// Paginated, sorted, searchable list of saved queries for an environment.
+// Query params: page, page_size, q (free-text), sort (column key), dir (asc|desc).
+func (h *HandlersApi) SavedQueriesListHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	envVar := r.PathValue("env")
+	if envVar == "" {
+		apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil)
+		return
+	}
+	env, err := h.Envs.Get(envVar)
+	if err != nil {
+		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
+		return
+	}
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, env.UUID) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+
+	q := r.URL.Query()
+	page, _ := strconv.Atoi(q.Get("page"))
+	pageSize, _ := strconv.Atoi(q.Get("page_size"))
+	if pageSize <= 0 {
+		pageSize = 50
+	}
+	if pageSize > 500 {
+		pageSize = 500
+	}
+	if page <= 0 {
+		page = 1
+	}
+	search := q.Get("q")
+	sortCol := q.Get("sort")
+	desc := strings.ToLower(q.Get("dir")) != "asc"
+
+	result, err := h.Queries.GetSavedByEnvPaged(env.ID, search, page, pageSize, sortCol, desc)
+	if err != nil {
+		apiErrorResponse(w, "error getting saved queries", http.StatusInternalServerError, err)
+		return
+	}
+	items := make([]types.SavedQueryView, 0, len(result.Items))
+	for _, s := range result.Items {
+		items = append(items, savedQueryView(s))
+	}
+	var totalPages int
+	if result.TotalItems > 0 {
+		totalPages = int((result.TotalItems + int64(pageSize) - 1) / int64(pageSize))
+	}
+	resp := types.SavedQueriesPagedResponse{
+		Items:      items,
+		Page:       page,
+		PageSize:   pageSize,
+		TotalItems: result.TotalItems,
+		TotalPages: totalPages,
+	}
+	log.Debug().Msgf("Returned %d saved queries (page %d of %d)", len(items), page, totalPages)
+	h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp)
+}
+
+// SavedQueryCreateHandler - POST /api/v1/saved-queries/{env}
+//
+// Body: { "name": string, "query": string }. Returns 201 with the created view,
+// 409 if a saved query with that name already exists in the environment.
+func (h *HandlersApi) SavedQueryCreateHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	envVar := r.PathValue("env")
+	if envVar == "" {
+		apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil)
+		return
+	}
+	env, err := h.Envs.Get(envVar)
+	if err != nil {
+		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
+		return
+	}
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, env.UUID) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+
+	var body types.SavedQueryCreateRequest
+	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+		apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err)
+		return
+	}
+	body.Name = strings.TrimSpace(body.Name)
+	body.Query = strings.TrimSpace(body.Query)
+	if body.Name == "" {
+		apiErrorResponse(w, "name cannot be empty", http.StatusBadRequest, nil)
+		return
+	}
+	if body.Query == "" {
+		apiErrorResponse(w, "query cannot be empty", http.StatusBadRequest, nil)
+		return
+	}
+	// The DB unique index on (name, environment_id) is the authoritative
+	// gate (see pkg/queries.SavedQuery + ErrSavedQueryExists). The
+	// SavedExists probe stays as a fast-path so the typical "this name
+	// is already taken" case returns 409 without hitting Create at all;
+	// races where two POSTs slip past SavedExists are caught by the
+	// duplicate-key error from CreateSaved.
+	if h.Queries.SavedExists(body.Name, env.ID) {
+		apiErrorResponse(w, "saved query with that name already exists", http.StatusConflict, nil)
+		return
+	}
+
+	creator := ctx[ctxUser]
+	if err := h.Queries.CreateSaved(body.Name, body.Query, creator, env.ID); err != nil {
+		if errors.Is(err, queries.ErrSavedQueryExists) {
+			apiErrorResponse(w, "saved query with that name already exists", http.StatusConflict, err)
+			return
+		}
+		apiErrorResponse(w, "error creating saved query", http.StatusInternalServerError, err)
+		return
+	}
+	saved, err := h.Queries.GetSavedByEnv(body.Name, env.ID)
+	if err != nil {
+		apiErrorResponse(w, "error fetching newly created saved query", http.StatusInternalServerError, err)
+		return
+	}
+
+	h.AuditLog.SavedQueryAction(creator, "create "+body.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID)
+	log.Debug().Msgf("Created saved query %s in env %s", body.Name, env.UUID)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusCreated, savedQueryView(saved))
+}
+
+// SavedQueryUpdateHandler - PATCH /api/v1/saved-queries/{env}/{name}
+//
+// Body: { "query": string }. Updates the SQL body only; the original creator
+// is preserved. Returns the updated view.
+func (h *HandlersApi) SavedQueryUpdateHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	envVar := r.PathValue("env")
+	name := r.PathValue("name")
+	if envVar == "" || name == "" {
+		apiErrorResponse(w, "missing env or name", http.StatusBadRequest, nil)
+		return
+	}
+	env, err := h.Envs.Get(envVar)
+	if err != nil {
+		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
+		return
+	}
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, env.UUID) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+
+	var body types.SavedQueryUpdateRequest
+	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+		apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err)
+		return
+	}
+	body.Query = strings.TrimSpace(body.Query)
+	if body.Query == "" {
+		apiErrorResponse(w, "query cannot be empty", http.StatusBadRequest, nil)
+		return
+	}
+
+	if err := h.Queries.UpdateSaved(name, body.Query, env.ID); err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			apiErrorResponse(w, "saved query not found", http.StatusNotFound, err)
+			return
+		}
+		apiErrorResponse(w, "error updating saved query", http.StatusInternalServerError, err)
+		return
+	}
+	saved, err := h.Queries.GetSavedByEnv(name, env.ID)
+	if err != nil {
+		apiErrorResponse(w, "error fetching updated saved query", http.StatusInternalServerError, err)
+		return
+	}
+	h.AuditLog.SavedQueryAction(ctx[ctxUser], "update "+name, strings.Split(r.RemoteAddr, ":")[0], env.ID)
+	log.Debug().Msgf("Updated saved query %s in env %s", name, env.UUID)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, savedQueryView(saved))
+}
+
+// SavedQueryDeleteHandler - DELETE /api/v1/saved-queries/{env}/{name}
+func (h *HandlersApi) SavedQueryDeleteHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	envVar := r.PathValue("env")
+	name := r.PathValue("name")
+	if envVar == "" || name == "" {
+		apiErrorResponse(w, "missing env or name", http.StatusBadRequest, nil)
+		return
+	}
+	env, err := h.Envs.Get(envVar)
+	if err != nil {
+		apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, nil)
+		return
+	}
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.QueryLevel, env.UUID) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+
+	if err := h.Queries.DeleteSavedByEnv(name, env.ID); err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			apiErrorResponse(w, "saved query not found", http.StatusNotFound, err)
+			return
+		}
+		apiErrorResponse(w, "error deleting saved query", http.StatusInternalServerError, err)
+		return
+	}
+	h.AuditLog.SavedQueryAction(ctx[ctxUser], "delete "+name, strings.Split(r.RemoteAddr, ":")[0], env.ID)
+	log.Debug().Msgf("Deleted saved query %s in env %s", name, env.UUID)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiGenericResponse{Message: fmt.Sprintf("saved query %s deleted", name)})
+}
diff --git a/cmd/api/handlers/settings.go b/cmd/api/handlers/settings.go
index 985fbabd..f2baa8f0 100644
--- a/cmd/api/handlers/settings.go
+++ b/cmd/api/handlers/settings.go
@@ -95,7 +95,7 @@ func (h *HandlersApi) SettingsServiceEnvHandler(w http.ResponseWriter, r *http.R
 		return
 	}
 	// Get environment by name
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		if err.Error() == "record not found" {
 			apiErrorResponse(w, "environment not found", http.StatusNotFound, err)
@@ -110,9 +110,9 @@ func (h *HandlersApi) SettingsServiceEnvHandler(w http.ResponseWriter, r *http.R
 		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
 		return
 	}
-	// Get settings scoped to THIS env. Previously this passed
-	// NoEnvironmentID and silently returned global settings, which let an
-	// env-X admin read another env's values as a side-channel via the
+	// Get settings scoped to THIS env. This previously passed
+	// NoEnvironmentID and silently returned global settings, which let
+	// an env-X admin read another env's values as a side-channel via the
 	// env-scoped route.
 	serviceSettings, err := h.Settings.RetrieveValues(service, false, env.ID)
 	if err != nil {
@@ -184,7 +184,7 @@ func (h *HandlersApi) SettingsServiceEnvJSONHandler(w http.ResponseWriter, r *ht
 		return
 	}
 	// Get environment by name
-	env, err := h.Envs.GetByUUID(envVar)
+	env, err := h.Envs.Get(envVar)
 	if err != nil {
 		if err.Error() == "record not found" {
 			apiErrorResponse(w, "environment not found", http.StatusNotFound, err)
diff --git a/cmd/api/handlers/settings_patch.go b/cmd/api/handlers/settings_patch.go
new file mode 100644
index 00000000..69336813
--- /dev/null
+++ b/cmd/api/handlers/settings_patch.go
@@ -0,0 +1,111 @@
+package handlers
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+
+	"github.com/jmpsec/osctrl/pkg/settings"
+	"github.com/jmpsec/osctrl/pkg/types"
+	"github.com/jmpsec/osctrl/pkg/users"
+	"github.com/jmpsec/osctrl/pkg/utils"
+	"github.com/rs/zerolog/log"
+	"gorm.io/gorm"
+)
+
+// SettingPatchHandler — PATCH /api/v1/settings/{service}/{name}
+//
+// Body shape (exactly one of `string`, `boolean`, `integer`):
+//
+//	{ "string": "value" }
+//	{ "boolean": true }
+//	{ "integer": 42 }
+//
+// The handler reads the existing setting first to determine its type, then
+// applies the matching typed setter. Mismatched payloads return 400. The
+// setting must already exist (creation is the legacy admin's job); a missing
+// setting → 404. Audit-log on success only.
+func (h *HandlersApi) SettingPatchHandler(w http.ResponseWriter, r *http.Request) {
+	if h.DebugHTTPConfig.EnableHTTP {
+		utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody)
+	}
+	ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue)
+	if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) {
+		apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser]))
+		return
+	}
+	service := r.PathValue("service")
+	if service == "" {
+		apiErrorResponse(w, "missing service", http.StatusBadRequest, nil)
+		return
+	}
+	if !h.Settings.VerifyService(service) {
+		apiErrorResponse(w, "invalid service", http.StatusBadRequest, nil)
+		return
+	}
+	name := r.PathValue("name")
+	if name == "" {
+		apiErrorResponse(w, "missing name", http.StatusBadRequest, nil)
+		return
+	}
+
+	existing, err := h.Settings.RetrieveValue(service, name, settings.NoEnvironmentID)
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			apiErrorResponse(w, "setting not found", http.StatusNotFound, err)
+			return
+		}
+		apiErrorResponse(w, "error reading setting", http.StatusInternalServerError, err)
+		return
+	}
+
+	var body types.SettingPatchRequest
+	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+		apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err)
+		return
+	}
+
+	switch existing.Type {
+	case settings.TypeBoolean:
+		if body.Boolean == nil {
+			apiErrorResponse(w, "setting is boolean — provide `boolean` in body", http.StatusBadRequest, nil)
+			return
+		}
+		if err := h.Settings.SetBoolean(*body.Boolean, service, name, settings.NoEnvironmentID); err != nil {
+			apiErrorResponse(w, "error updating setting", http.StatusInternalServerError, err)
+			return
+		}
+	case settings.TypeInteger:
+		if body.Integer == nil {
+			apiErrorResponse(w, "setting is integer — provide `integer` in body", http.StatusBadRequest, nil)
+			return
+		}
+		if err := h.Settings.SetInteger(*body.Integer, service, name, settings.NoEnvironmentID); err != nil {
+			apiErrorResponse(w, "error updating setting", http.StatusInternalServerError, err)
+			return
+		}
+	case settings.TypeString:
+		if body.String == nil {
+			apiErrorResponse(w, "setting is string — provide `string` in body", http.StatusBadRequest, nil)
+			return
+		}
+		if err := h.Settings.SetString(*body.String, service, name, existing.JSON, settings.NoEnvironmentID); err != nil {
+			apiErrorResponse(w, "error updating setting", http.StatusInternalServerError, err)
+			return
+		}
+	default:
+		apiErrorResponse(w, "unsupported setting type", http.StatusInternalServerError, nil)
+		return
+	}
+
+	updated, err := h.Settings.RetrieveValue(service, name, settings.NoEnvironmentID)
+	if err != nil {
+		apiErrorResponse(w, "error reading updated setting", http.StatusInternalServerError, err)
+		return
+	}
+	h.AuditLog.SettingsAction(ctx[ctxUser], fmt.Sprintf("patch %s/%s", service, name), strings.Split(r.RemoteAddr, ":")[0])
+	log.Debug().Msgf("Patched setting %s/%s", service, name)
+	utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, updated)
+}
diff --git a/cmd/api/handlers/stats.go b/cmd/api/handlers/stats.go
new file mode 100644
index 00000000..800b0447
--- /dev/null
+++ b/cmd/api/handlers/stats.go
@@ -0,0 +1,539 @@
+package handlers
+
+import (
+	"fmt"
+	"net/http"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/jmpsec/osctrl/pkg/auditlog"
+	"github.com/jmpsec/osctrl/pkg/dbutil"
+	"github.com/jmpsec/osctrl/pkg/logging"
+	"github.com/jmpsec/osctrl/pkg/nodes"
+	"github.com/jmpsec/osctrl/pkg/queries"
+	"github.com/jmpsec/osctrl/pkg/settings"
+	"github.com/jmpsec/osctrl/pkg/users"
+	"github.com/jmpsec/osctrl/pkg/utils"
+	"github.com/rs/zerolog/log"
+)
+
+// EnvStats is one row in the per-env breakdown returned by /api/v1/stats.
+type EnvStats struct { + UUID string `json:"uuid"` + Name string `json:"name"` + Active int64 `json:"active"` + Inactive int64 `json:"inactive"` + Total int64 `json:"total"` + ActiveQueries int `json:"active_queries"` + ActiveCarves int `json:"active_carves"` + // PlatformCounts buckets the env's nodes by OS family (linux / darwin / + // windows / other). Drives the Nodes-table QuickFilters chip row. Counts + // are total (active + inactive), since the filter chip lists all nodes + // of that platform regardless of staleness — the Active/Inactive toggle + // is independent. + PlatformCounts nodes.PlatformCounts `json:"platform_counts"` +} + +// StatsResponse is the canonical /api/v1/stats shape consumed by the dashboard. +type StatsResponse struct { + // Cross-env totals (the user's allowed envs only). + TotalNodes int64 `json:"total_nodes"` + ActiveNodes int64 `json:"active_nodes"` + InactiveNodes int64 `json:"inactive_nodes"` + // TotalActiveQueries counts standard query-type active queries (excludes carves). + TotalActiveQueries int `json:"total_active_queries"` + // TotalActiveCarves counts active carve-type queries. + TotalActiveCarves int `json:"total_active_carves"` + // Cross-env platform breakdown — sum of every accessible env's PlatformCounts. + PlatformCounts nodes.PlatformCounts `json:"platform_counts"` + + // Per-env breakdown, in stable alphabetical order by name. + Environments []EnvStats `json:"environments"` +} + +// StatsHandler returns cross-env totals + per-env counts, filtered to the +// envs the calling user has UserLevel access to. Used by the SPA dashboard. +// +// No query params. The response is small (one entry per accessible env) and +// cacheable for 30s on the client (Cache-Control: private, max-age=30). +// +// NOTE on query/carve counting: +// - GetActive(envID) returns ALL active rows regardless of type (union). 
+// - To avoid double-counting we call GetQueries("active", envID) for +// standard queries and GetCarves("active", envID) for carves separately. +// - Unit test for this handler is deferred: the underlying pkg/queries +// functions are exercised by existing tests in pkg/queries; a full +// integration test would require DB fixture setup that is out of scope +// for Track 2. +func (h *HandlersApi) StatsHandler(w http.ResponseWriter, r *http.Request) { + ctxVal := r.Context().Value(ContextKey(contextAPI)) + if ctxVal == nil { + apiErrorResponse(w, "missing auth context", http.StatusUnauthorized, nil) + return + } + ctx := ctxVal.(ContextValue) + user := ctx[ctxUser] + + allEnvs, err := h.Envs.All() + if err != nil { + apiErrorResponse(w, "failed to load environments", http.StatusInternalServerError, err) + return + } + + hours := h.Settings.InactiveHours(settings.NoEnvironmentID) + out := StatsResponse{Environments: make([]EnvStats, 0, len(allEnvs))} + + for _, e := range allEnvs { + // Filter to envs the user can actually see. + if !h.Users.CheckPermissions(user, users.UserLevel, e.UUID) { + continue + } + + ns, err := h.Nodes.GetStatsByEnv(e.Name, hours) + if err != nil { + log.Warn().Err(err).Str("env", e.Name).Msg("stats: failed to get node stats, skipping env") + continue + } + + // Per-env platform counts (linux / darwin / windows / other) for the + // SPA's filter chips. We don't fail the whole env on a count error; + // if the GROUP BY fails the env still gets a row, just with zeros in + // PlatformCounts. The SPA renders the chips as "0" rather than missing. + platCounts, err := h.Nodes.GetPlatformCountsByEnv(e.Name) + if err != nil { + log.Warn().Err(err).Str("env", e.Name).Msg("stats: failed to get platform counts, defaulting to zeros") + } + + // Use type-specific methods to avoid double-counting: + // GetQueries returns StandardQueryType active items only. + // GetCarves returns CarveQueryType active items only. 
+ activeQ, err := h.Queries.GetQueries(queries.TargetActive, e.ID) + if err != nil { + log.Warn().Err(err).Str("env", e.Name).Msg("stats: failed to count active queries, skipping env") + continue + } + activeC, err := h.Queries.GetCarves(queries.TargetActive, e.ID) + if err != nil { + log.Warn().Err(err).Str("env", e.Name).Msg("stats: failed to count active carves, skipping env") + continue + } + + row := EnvStats{ + UUID: e.UUID, + Name: e.Name, + Active: ns.Active, + Inactive: ns.Inactive, + Total: ns.Total, + ActiveQueries: len(activeQ), + ActiveCarves: len(activeC), + PlatformCounts: platCounts, + } + out.Environments = append(out.Environments, row) + out.ActiveNodes += ns.Active + out.InactiveNodes += ns.Inactive + out.TotalNodes += ns.Total + out.TotalActiveQueries += len(activeQ) + out.TotalActiveCarves += len(activeC) + // Aggregate cross-env platform totals. + out.PlatformCounts.Linux += platCounts.Linux + out.PlatformCounts.Darwin += platCounts.Darwin + out.PlatformCounts.Windows += platCounts.Windows + out.PlatformCounts.Other += platCounts.Other + } + + // Stable alphabetical order by env name. + sort.Slice(out.Environments, func(i, j int) bool { + return out.Environments[i].Name < out.Environments[j].Name + }) + + w.Header().Set("Cache-Control", "private, max-age=30") + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, out) +} + +// ActivityBucket is one cell of the 24-hour activity heatmap. BucketStart is +// the start of the 15-minute window (UTC, RFC3339); the four counters are +// the audit-log entry counts that fell into that window for each category. 
+// +// Categories (audit log_type → category): +// - config ← Setting (8) + Environment (7) +// - query ← Query (4) +// - carve ← Carve (5) +// - enroll ← Node (3) — covers enroll, archive, deletion +type ActivityBucket struct { + BucketStart time.Time `json:"bucket_start"` + Config int `json:"config"` + Query int `json:"query"` + Carve int `json:"carve"` + Enroll int `json:"enroll"` +} + +// activityIntervalPresets maps the SPA's interval picker values to (hours, +// bucketSeconds). Bucket sizes are chosen so the cell count stays in the +// 36..96 range across the full picker — small enough to fit one row at +// 1280px, large enough that the heatmap still reads as a sparse density map. +// +// Adding a new preset: pick a bucketSeconds that divides hours*3600 evenly +// to avoid an under-filled trailing cell. +type activityPreset struct { + bucketSeconds int +} + +var activityIntervalPresets = map[string]activityPreset{ + "3h": {bucketSeconds: 5 * 60}, // 36 cells + "6h": {bucketSeconds: 5 * 60}, // 72 cells + "12h": {bucketSeconds: 10 * 60}, // 72 cells + "1d": {bucketSeconds: 15 * 60}, // 96 cells + "2d": {bucketSeconds: 30 * 60}, // 96 cells + "3d": {bucketSeconds: 45 * 60}, // 96 cells + "7d": {bucketSeconds: 2 * 3600}, // 84 cells +} + +var activityIntervalHours = map[string]int{ + "3h": 3, "6h": 6, "12h": 12, "1d": 24, "2d": 48, "3d": 72, "7d": 168, +} + +// EnvActivityHandler — GET /api/v1/stats/activity/{env}?interval=KEY +// +// Returns audit-log activity for one env over the requested interval, +// bucketed at a fixed size per interval (see activityIntervalPresets). +// `interval` accepts 3h / 6h / 12h / 1d / 2d / 3d / 7d (default 1d, falls +// back to 1d on any unknown value rather than 400ing — the SPA picker is +// the only allowed source). +// +// Buckets are emitted contiguously — empty windows return zero rows for +// that bucket — so the SPA can render the grid without densifying +// client-side. 
+func (h *HandlersApi) EnvActivityHandler(w http.ResponseWriter, r *http.Request) { + ctxVal := r.Context().Value(ContextKey(contextAPI)) + if ctxVal == nil { + apiErrorResponse(w, "missing auth context", http.StatusUnauthorized, nil) + return + } + ctx := ctxVal.(ContextValue) + user := ctx[ctxUser] + + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "error with environment", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + apiErrorResponse(w, "error getting environment", http.StatusNotFound, err) + return + } + if !h.Users.CheckPermissions(user, users.UserLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", user)) + return + } + + intervalKey := r.URL.Query().Get("interval") + preset, ok := activityIntervalPresets[intervalKey] + if !ok { + intervalKey = "1d" + preset = activityIntervalPresets["1d"] + } + hours := activityIntervalHours[intervalKey] + bucketSeconds := preset.bucketSeconds + totalSeconds := hours * 3600 + nBuckets := totalSeconds / bucketSeconds + + // Align the strip to the most-recent bucket boundary so the rightmost + // column always represents "now" rather than a partial bucket. Avoids + // the visual confusion of an under-filled trailing cell. + now := time.Now().UTC() + endBucket := time.Unix((now.Unix()/int64(bucketSeconds))*int64(bucketSeconds), 0).UTC() + startBucket := endBucket.Add(-time.Duration(nBuckets-1) * time.Duration(bucketSeconds) * time.Second) + + rows, err := h.AuditLog.GetEnvActivityBucketed(env.ID, startBucket, bucketSeconds) + if err != nil { + apiErrorResponse(w, "failed to load activity", http.StatusInternalServerError, err) + return + } + + // Pre-allocate the contiguous bucket array so empty windows still ship a + // row. Indexing is by `(bucket_start - startUnix) / bucketSeconds`; + // rows that fall outside [0, nBuckets-1] are dropped.
+ startUnix := startBucket.Unix() + out := make([]ActivityBucket, nBuckets) + for i := range out { + out[i].BucketStart = startBucket.Add(time.Duration(i) * time.Duration(bucketSeconds) * time.Second) + } + for _, row := range rows { + idx := int((row.BucketStart - startUnix) / int64(bucketSeconds)) + if idx < 0 || idx >= nBuckets { + continue + } + switch row.LogType { + case auditlog.LogTypeSetting, auditlog.LogTypeEnvironment: + out[idx].Config += int(row.Cnt) + case auditlog.LogTypeQuery: + out[idx].Query += int(row.Cnt) + case auditlog.LogTypeCarve: + out[idx].Carve += int(row.Cnt) + case auditlog.LogTypeNode: + out[idx].Enroll += int(row.Cnt) + } + } + + w.Header().Set("Cache-Control", "private, max-age=30") + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, out) +} + +// NodeActivityBucket is one cell of the per-node 24h activity heatmap. +// Categories pivot from the env-scoped variant — node-scoped activity is +// about what THIS device has been doing, not what operators have done to +// the env. So: +// - status ← osquery_status_data row count (status logs received from this node) +// - result ← osquery_result_data row count (query results returned by this node) +// - query ← node_queries row count (distributed queries scheduled against this node) +// - carve ← carved_files row count (carves this node has produced) +// +// All four are joinable by node uuid (or numeric node id for node_queries). +type NodeActivityBucket struct { + BucketStart time.Time `json:"bucket_start"` + Status int `json:"status"` + Result int `json:"result"` + Query int `json:"query"` + Carve int `json:"carve"` +} + +// NodeActivityHandler — GET /api/v1/stats/activity/node/{env}/{uuid}?interval=KEY +// +// Per-node version of EnvActivityHandler. Same bucketing rules (see +// activityIntervalPresets). 
The four categories partition different DB +// tables (see NodeActivityBucket) keyed by the node's UUID — except +// node_queries which keys by numeric NodeID, looked up once from the +// resolved node. +func (h *HandlersApi) NodeActivityHandler(w http.ResponseWriter, r *http.Request) { + ctxVal := r.Context().Value(ContextKey(contextAPI)) + if ctxVal == nil { + apiErrorResponse(w, "missing auth context", http.StatusUnauthorized, nil) + return + } + ctx := ctxVal.(ContextValue) + user := ctx[ctxUser] + + envVar := r.PathValue("env") + uuidVar := r.PathValue("uuid") + if envVar == "" || uuidVar == "" { + apiErrorResponse(w, "env and uuid required", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + apiErrorResponse(w, "error getting environment", http.StatusNotFound, err) + return + } + if !h.Users.CheckPermissions(user, users.UserLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", user)) + return + } + // Resolve the node — gives us the numeric NodeID for the node_queries + // join and lets us reject probes for arbitrary UUIDs across tenants. 
+ node, err := h.Nodes.GetByUUID(uuidVar) + if err != nil { + apiErrorResponse(w, "node not found", http.StatusNotFound, err) + return + } + if !strings.EqualFold(node.Environment, env.Name) { + apiErrorResponse(w, "node not in environment", http.StatusForbidden, nil) + return + } + + intervalKey := r.URL.Query().Get("interval") + preset, ok := activityIntervalPresets[intervalKey] + if !ok { + intervalKey = "1d" + preset = activityIntervalPresets["1d"] + } + hours := activityIntervalHours[intervalKey] + bucketSeconds := preset.bucketSeconds + totalSeconds := hours * 3600 + nBuckets := totalSeconds / bucketSeconds + + now := time.Now().UTC() + endBucket := time.Unix((now.Unix()/int64(bucketSeconds))*int64(bucketSeconds), 0).UTC() + startBucket := endBucket.Add(-time.Duration(nBuckets-1) * time.Duration(bucketSeconds) * time.Second) + + out := h.computeNodeActivityForNode(env.Name, node.UUID, node.ID, startBucket, bucketSeconds, nBuckets) + w.Header().Set("Cache-Control", "private, max-age=30") + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, out) +} + +// computeNodeActivityForNode runs the 4-table bucketed-count pipeline for +// one node and returns the dense bucket array. Shared by both +// NodeActivityHandler and NodeActivityBatchHandler so the bucketing rules +// stay in one place. +// +// Each category issues a single SQL GROUP BY rather than plucking every +// CreatedAt — at 50k+ nodes a chatty status_data table would otherwise +// stream tens of thousands of timestamps per Nodes page row. +// Fail-soft per category: a single-table error still renders the others. 
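A minimal sketch of the densify step this pipeline relies on, assuming a row shape inferred from the call sites (the real dbutil.DensifyBuckets may differ in signature): sparse GROUP BY rows expand into a contiguous count array, with rows outside the strip dropped.

```go
package main

import "fmt"

// BucketRow mirrors the sparse GROUP BY output consumed above; the field
// names are assumptions based on the call sites, not the real dbutil type.
type BucketRow struct {
	BucketStart int64 // unix seconds, already snapped to a bucket boundary
	Cnt         int64
}

// densify expands sparse rows into a contiguous per-bucket count array so
// empty windows still ship a zero cell. Out-of-strip rows are skipped
// rather than panicking.
func densify(rows []BucketRow, startUnix int64, bucketSeconds, nBuckets int) []int64 {
	out := make([]int64, nBuckets)
	for _, r := range rows {
		idx := int((r.BucketStart - startUnix) / int64(bucketSeconds))
		if idx < 0 || idx >= nBuckets {
			continue
		}
		out[idx] += r.Cnt
	}
	return out
}

func main() {
	rows := []BucketRow{{BucketStart: 900, Cnt: 3}, {BucketStart: 2700, Cnt: 1}}
	fmt.Println(densify(rows, 0, 900, 4)) // [0 3 0 1]
}
```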
+func (h *HandlersApi) computeNodeActivityForNode( + envName string, + nodeUUID string, + nodeID uint, + startBucket time.Time, + bucketSeconds int, + nBuckets int, +) []NodeActivityBucket { + startUnix := startBucket.Unix() + + statusRows, err := logging.GetNodeStatusBucketed(h.DB, envName, nodeUUID, startBucket, bucketSeconds) + if err != nil { + log.Warn().Err(err).Str("node", nodeUUID).Msg("node-activity: status bucketed failed") + } + resultRows, err := logging.GetNodeResultBucketed(h.DB, envName, nodeUUID, startBucket, bucketSeconds) + if err != nil { + log.Warn().Err(err).Str("node", nodeUUID).Msg("node-activity: result bucketed failed") + } + queryRows, err := h.Queries.GetNodeQueryBucketed(nodeID, startBucket, bucketSeconds) + if err != nil { + log.Warn().Err(err).Str("node", nodeUUID).Msg("node-activity: node-query bucketed failed") + } + carveRows, err := h.Carves.GetNodeCarveBucketed(nodeUUID, startBucket, bucketSeconds) + if err != nil { + log.Warn().Err(err).Str("node", nodeUUID).Msg("node-activity: carve bucketed failed") + } + + statusDense := dbutil.DensifyBuckets(statusRows, startUnix, bucketSeconds, nBuckets) + resultDense := dbutil.DensifyBuckets(resultRows, startUnix, bucketSeconds, nBuckets) + queryDense := dbutil.DensifyBuckets(queryRows, startUnix, bucketSeconds, nBuckets) + carveDense := dbutil.DensifyBuckets(carveRows, startUnix, bucketSeconds, nBuckets) + + out := make([]NodeActivityBucket, nBuckets) + for i := range out { + out[i].BucketStart = startBucket.Add(time.Duration(i) * time.Duration(bucketSeconds) * time.Second) + out[i].Status = int(statusDense[i]) + out[i].Result = int(resultDense[i]) + out[i].Query = int(queryDense[i]) + out[i].Carve = int(carveDense[i]) + } + return out +} + +// NodeActivityBatchHandler — GET /api/v1/stats/activity/node-batch/{env}?uuids=A,B,C&interval=KEY +// +// Returns activity buckets for up to 100 nodes in one call. 
The response is +// a map keyed by node UUID so the SPA can render a sparkline per row in the +// Nodes table without firing N parallel requests. +// +// Cap is 100 to bound the per-request DB load — each node still issues 4 +// bucketed-count queries. The SPA's pagination already caps pages at 500; for +// pages above 100 nodes the SPA fans out 2-3 batch requests instead. +// +// Unknown / unauthorized UUIDs are silently omitted from the response +// (they're treated as "no data"), not 404'd — so a single bad UUID +// in the list doesn't break the whole page render. +func (h *HandlersApi) NodeActivityBatchHandler(w http.ResponseWriter, r *http.Request) { + ctxVal := r.Context().Value(ContextKey(contextAPI)) + if ctxVal == nil { + apiErrorResponse(w, "missing auth context", http.StatusUnauthorized, nil) + return + } + ctx := ctxVal.(ContextValue) + user := ctx[ctxUser] + + envVar := r.PathValue("env") + if envVar == "" { + apiErrorResponse(w, "env required", http.StatusBadRequest, nil) + return + } + env, err := h.Envs.Get(envVar) + if err != nil { + apiErrorResponse(w, "error getting environment", http.StatusNotFound, err) + return + } + if !h.Users.CheckPermissions(user, users.UserLevel, env.UUID) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", user)) + return + } + + uuidsParam := strings.TrimSpace(r.URL.Query().Get("uuids")) + if uuidsParam == "" { + // Empty request → empty response. Keeps the page from breaking when + // the SPA's `nodes` query returns 0 rows (zero-length CSV). + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, map[string][]NodeActivityBucket{}) + return + } + rawUUIDs := strings.Split(uuidsParam, ",") + const maxBatch = 100 + if len(rawUUIDs) > maxBatch { + rawUUIDs = rawUUIDs[:maxBatch] + } + // Dedupe + normalize (upper-case, like the DB stores them).
+ seen := make(map[string]struct{}, len(rawUUIDs)) + uuids := rawUUIDs[:0] + for _, u := range rawUUIDs { + u = strings.ToUpper(strings.TrimSpace(u)) + if u == "" { + continue + } + if _, dup := seen[u]; dup { + continue + } + seen[u] = struct{}{} + uuids = append(uuids, u) + } + + intervalKey := r.URL.Query().Get("interval") + preset, ok := activityIntervalPresets[intervalKey] + if !ok { + intervalKey = "1d" + preset = activityIntervalPresets["1d"] + } + hours := activityIntervalHours[intervalKey] + bucketSeconds := preset.bucketSeconds + totalSeconds := hours * 3600 + nBuckets := totalSeconds / bucketSeconds + + now := time.Now().UTC() + endBucket := time.Unix((now.Unix()/int64(bucketSeconds))*int64(bucketSeconds), 0).UTC() + startBucket := endBucket.Add(-time.Duration(nBuckets-1) * time.Duration(bucketSeconds) * time.Second) + + out := make(map[string][]NodeActivityBucket, len(uuids)) + for _, u := range uuids { + // Per-uuid resolution. A miss is logged-but-skipped rather than + // failed-the-whole-batch — see handler comment for rationale. + node, err := h.Nodes.GetByUUID(u) + if err != nil { + log.Debug().Err(err).Str("node", u).Msg("node-activity-batch: uuid not found, skipping") + continue + } + if !strings.EqualFold(node.Environment, env.Name) { + log.Debug().Str("node", u).Msg("node-activity-batch: uuid not in env, skipping") + continue + } + out[node.UUID] = h.computeNodeActivityForNode(env.Name, node.UUID, node.ID, startBucket, bucketSeconds, nBuckets) + } + + w.Header().Set("Cache-Control", "private, max-age=30") + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, out) +} + +// OsqueryVersionsHandler — GET /api/v1/stats/osquery-versions. +// +// Returns fleet-wide osquery agent version breakdown for the dashboard's +// "fleet hygiene" panel. Operators use this to spot stale agents that need +// upgrading. Cross-env (no env filter); the dashboard already surfaces the +// per-env breakdown in its env tiles. 
+// +// Counts include both active and inactive nodes — a node sitting at an old +// osquery version is still "stale" even if it's offline today, because once +// it comes back online it'll come back stale. +func (h *HandlersApi) OsqueryVersionsHandler(w http.ResponseWriter, r *http.Request) { + ctxVal := r.Context().Value(ContextKey(contextAPI)) + if ctxVal == nil { + apiErrorResponse(w, "missing auth context", http.StatusUnauthorized, nil) + return + } + rows, err := h.Nodes.GetOsqueryVersionCounts() + if err != nil { + apiErrorResponse(w, "failed to load osquery versions", http.StatusInternalServerError, err) + return + } + w.Header().Set("Cache-Control", "private, max-age=60") + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, rows) +} diff --git a/cmd/api/handlers/stats_test.go b/cmd/api/handlers/stats_test.go new file mode 100644 index 00000000..b374e88e --- /dev/null +++ b/cmd/api/handlers/stats_test.go @@ -0,0 +1,94 @@ +package handlers + +import ( + "encoding/json" + "testing" +) + +// TestStatsResponseShape verifies the JSON tags on the response types are +// snake_case and match the OpenAPI schema field names. This catches regressions +// where a field rename in Go doesn't propagate to the JSON output shape. +// +// Full integration tests (DB-backed) are deferred: the underlying +// pkg/nodes.GetStatsByEnv and pkg/queries.GetQueries/GetCarves are covered by +// their own package tests. A handler-level integration test would require +// substantial DB fixturing that is out of scope for Track 2. 
+func TestStatsResponseShape(t *testing.T) { + resp := StatsResponse{ + TotalNodes: 10, + ActiveNodes: 7, + InactiveNodes: 3, + TotalActiveQueries: 2, + TotalActiveCarves: 1, + Environments: []EnvStats{ + { + UUID: "env-uuid-1", + Name: "prod", + Active: 5, + Inactive: 2, + Total: 7, + ActiveQueries: 1, + ActiveCarves: 0, + }, + }, + } + + b, err := json.Marshal(resp) + if err != nil { + t.Fatalf("json.Marshal(StatsResponse): %v", err) + } + + var m map[string]interface{} + if err := json.Unmarshal(b, &m); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + + // Verify top-level snake_case field names. + topLevel := []string{ + "total_nodes", + "active_nodes", + "inactive_nodes", + "total_active_queries", + "total_active_carves", + "platform_counts", + "environments", + } + for _, key := range topLevel { + if _, ok := m[key]; !ok { + t.Errorf("StatsResponse JSON missing field %q", key) + } + } + + // Verify per-env field names in the first environments entry. + envs, ok := m["environments"].([]interface{}) + if !ok || len(envs) == 0 { + t.Fatal("StatsResponse.environments is empty or wrong type") + } + envMap, ok := envs[0].(map[string]interface{}) + if !ok { + t.Fatal("environments[0] is not a JSON object") + } + envLevel := []string{ + "uuid", + "name", + "active", + "inactive", + "total", + "active_queries", + "active_carves", + "platform_counts", + } + for _, key := range envLevel { + if _, ok := envMap[key]; !ok { + t.Errorf("EnvStats JSON missing field %q", key) + } + } + + // Verify numeric totals round-trip correctly. 
+ if got := m["total_nodes"]; got != float64(10) { + t.Errorf("total_nodes = %v, want 10", got) + } + if got := m["active_nodes"]; got != float64(7) { + t.Errorf("active_nodes = %v, want 7", got) + } +} diff --git a/cmd/api/handlers/tags.go b/cmd/api/handlers/tags.go index 552045a2..801aaba2 100644 --- a/cmd/api/handlers/tags.go +++ b/cmd/api/handlers/tags.go @@ -38,26 +38,25 @@ func (h *HandlersApi) AllTagsHandler(w http.ResponseWriter, r *http.Request) { utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, tags) } -// TagEnvHandler - GET Handler to return one tag for one environment as JSON +// TagEnvHandler - GET Handler to return one tag for one environment as JSON. +// Permission is scoped to env.UUID admin so non-super operators with admin +// rights on this specific environment can view its tags. func (h *HandlersApi) TagEnvHandler(w http.ResponseWriter, r *http.Request) { // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error getting environment", http.StatusBadRequest, nil) return } - // Extract tag name tagVar := r.PathValue("name") if tagVar == "" { apiErrorResponse(w, "error getting tag name", http.StatusBadRequest, nil) return } - // Get environment by UUID - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -66,38 +65,33 @@ func (h *HandlersApi) TagEnvHandler(w http.ResponseWriter, r *http.Request) { } return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, 
fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get tag exist, tag := h.Tags.ExistsGet(tagVar, env.ID) if !exist { - apiErrorResponse(w, "error getting tag", http.StatusInternalServerError, err) + apiErrorResponse(w, "tag not found", http.StatusNotFound, nil) return } - // Serialize and serve JSON log.Debug().Msg("Returned tag") h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, tag) } -// TagsEnvHandler - GET Handler to return tags for one environment as JSON +// TagsEnvHandler - GET Handler to return tags for one environment as JSON. +// Permission is scoped to env.UUID admin (see TagEnvHandler note). func (h *HandlersApi) TagsEnvHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error getting environment", http.StatusBadRequest, nil) return } - // Get environment by UUID - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -106,38 +100,39 @@ func (h *HandlersApi) TagsEnvHandler(w http.ResponseWriter, r *http.Request) { } return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Get tags - tags, err := h.Tags.GetByEnv(env.ID) + tagList, err := h.Tags.GetByEnv(env.ID) if err != nil { apiErrorResponse(w, "error getting tags", 
http.StatusInternalServerError, err) return } - // Serialize and serve JSON - log.Debug().Msgf("Returned %d tags", len(tags)) + // Empty list is a valid state — never return 404 on listing. + if tagList == nil { + tagList = []tags.AdminTag{} + } + log.Debug().Msgf("Returned %d tags", len(tagList)) h.AuditLog.Visit(ctx[ctxUser], r.URL.Path, strings.Split(r.RemoteAddr, ":")[0], env.ID) - utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, tags) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, tagList) } -// TagsActionHandler - POST Handler to create, update or delete tags +// TagsActionHandler - POST Handler to create / update / delete tags. The +// action arrives as a URL path segment (legacy contract retained because +// Track 6 doesn't introduce new tag routes); body validation surfaces 400 +// on parse error and 409 on duplicate-name conflicts. func (h *HandlersApi) TagsActionHandler(w http.ResponseWriter, r *http.Request) { - // Debug HTTP if enabled if h.DebugHTTPConfig.EnableHTTP { utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) } - // Extract environment envVar := r.PathValue("env") if envVar == "" { apiErrorResponse(w, "error getting environment", http.StatusBadRequest, nil) return } - // Get environment by UUID - env, err := h.Envs.GetByUUID(envVar) + env, err := h.Envs.Get(envVar) if err != nil { if err.Error() == "record not found" { apiErrorResponse(w, "environment not found", http.StatusNotFound, err) @@ -146,37 +141,42 @@ func (h *HandlersApi) TagsActionHandler(w http.ResponseWriter, r *http.Request) } return } - // Get context data and check access ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) - if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, env.UUID) { apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) return } - // Extract 
action actionVar := r.PathValue("action") if actionVar == "" { apiErrorResponse(w, "error getting action", http.StatusBadRequest, nil) return } var t types.ApiTagsRequest - // Parse request JSON body if err := json.NewDecoder(r.Body).Decode(&t); err != nil { - apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) + apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err) + return + } + if t.Name == "" { + apiErrorResponse(w, "tag name cannot be empty", http.StatusBadRequest, nil) + return } var returnData string switch actionVar { case tags.ActionAdd: if h.Tags.ExistsByEnv(t.Name, env.ID) { - apiErrorResponse(w, "error adding tag", http.StatusInternalServerError, fmt.Errorf("tag %s already exists", t.Name)) + apiErrorResponse(w, "tag with that name already exists in this environment", http.StatusConflict, nil) return } if err := h.Tags.NewTag(t.Name, t.Description, t.Color, t.Icon, ctx[ctxUser], env.ID, false, t.TagType, t.Custom); err != nil { - apiErrorResponse(w, "error with new tag", http.StatusInternalServerError, err) + apiErrorResponse(w, "error creating tag", http.StatusInternalServerError, err) return } returnData = "tag added successfully" case tags.ActionEdit: + if !h.Tags.ExistsByEnv(t.Name, env.ID) { + apiErrorResponse(w, "tag not found", http.StatusNotFound, nil) + return + } tag, err := h.Tags.Get(t.Name, env.ID) if err != nil { apiErrorResponse(w, "error getting tag", http.StatusInternalServerError, err) @@ -218,13 +218,19 @@ func (h *HandlersApi) TagsActionHandler(w http.ResponseWriter, r *http.Request) } returnData = "tag updated successfully" case tags.ActionRemove: + if !h.Tags.ExistsByEnv(t.Name, env.ID) { + apiErrorResponse(w, "tag not found", http.StatusNotFound, nil) + return + } if err := h.Tags.DeleteGet(t.Name, env.ID); err != nil { apiErrorResponse(w, "error removing tag", http.StatusInternalServerError, err) return } returnData = "tag removed successfully" + default: + apiErrorResponse(w,
"invalid action", http.StatusBadRequest, nil) + return } - // Serialize and serve JSON log.Debug().Msgf("Returned [%s]", returnData) h.AuditLog.TagAction(ctx[ctxUser], actionVar+" tag "+t.Name, strings.Split(r.RemoteAddr, ":")[0], env.ID) utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiDataResponse{Data: returnData}) diff --git a/cmd/api/handlers/users_profile.go b/cmd/api/handlers/users_profile.go new file mode 100644 index 00000000..1da560ed --- /dev/null +++ b/cmd/api/handlers/users_profile.go @@ -0,0 +1,293 @@ +package handlers + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/jmpsec/osctrl/pkg/types" + "github.com/jmpsec/osctrl/pkg/users" + "github.com/jmpsec/osctrl/pkg/utils" + "github.com/rs/zerolog/log" + "gorm.io/gorm" +) + +const tokenRefreshDefaultHours = 24 + +// SetUserPermissionsHandler - POST /api/v1/users/{username}/permissions +// +// Body: { env_uuid, access: { user, query, carve, admin } }. Replaces the +// target user's per-env access rows. Returns 200 with the new EnvAccess. +// Requires super-admin (AdminLevel, NoEnvironment) — env-scoped admins can +// not grant permissions for their environment from this endpoint. 
+func (h *HandlersApi) SetUserPermissionsHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + if !h.Users.CheckPermissions(ctx[ctxUser], users.AdminLevel, users.NoEnvironment) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to use API by user %s", ctx[ctxUser])) + return + } + username := r.PathValue("username") + if username == "" { + apiErrorResponse(w, "missing username", http.StatusBadRequest, nil) + return + } + if !h.Users.Exists(username) { + apiErrorResponse(w, "user not found", http.StatusNotFound, nil) + return + } + + var body types.SetPermissionsRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err) + return + } + body.EnvUUID = strings.TrimSpace(body.EnvUUID) + if body.EnvUUID == "" { + apiErrorResponse(w, "env_uuid is required", http.StatusBadRequest, nil) + return + } + if _, err := h.Envs.GetByUUID(body.EnvUUID); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "environment not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + return + } + + access := users.EnvAccess{ + User: body.Access.User, + Query: body.Access.Query, + Carve: body.Access.Carve, + Admin: body.Access.Admin, + } + + // Lockout guards. A super-admin cannot: + // 1. Self-demote — granting yourself a strict downgrade via this + // endpoint risks locking yourself out of further permission + // changes if no other super-admin exists. Force the operator + // to go through another super-admin. + // 2. Demote the LAST super-admin under any path. 
If admin=false + // and the target is the only AdminUser.Admin=true row, the + // system has no remaining super-admin and no one can manage + // users / envs / settings. Refuse with 409. + if username == ctx[ctxUser] && !access.Admin { + apiErrorResponse(w, "super-admins cannot self-demote via this endpoint", http.StatusForbidden, nil) + return + } + if !access.Admin && h.Users.IsAdmin(username) { + count, cerr := h.Users.CountAdmins() + if cerr != nil { + apiErrorResponse(w, "error checking admin count", http.StatusInternalServerError, cerr) + return + } + if count <= 1 { + apiErrorResponse(w, "refusing to demote the last super-admin", http.StatusConflict, fmt.Errorf("only %d admin user(s) remain", count)) + return + } + } + + if err := h.Users.ChangeAccess(username, body.EnvUUID, access); err != nil { + apiErrorResponse(w, "error setting permissions", http.StatusInternalServerError, err) + return + } + + h.AuditLog.Permissions(ctx[ctxUser], + fmt.Sprintf("set %s on env=%s u=%v q=%v c=%v a=%v", + username, body.EnvUUID, access.User, access.Query, access.Carve, access.Admin), + strings.Split(r.RemoteAddr, ":")[0], 0) + log.Debug().Msgf("permissions updated for user %s on env %s", username, body.EnvUUID) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, body.Access) +} + +// RefreshUserTokenHandler - POST /api/v1/users/{username}/token/refresh +// +// Generates a new JWT for the target user, persists it as the user's +// APIToken (invalidating the previous token), and returns the new token + +// expiry. Requires super-admin OR the request author asking for their own +// token. Audit-logged on success. 
+func (h *HandlersApi) RefreshUserTokenHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + username := r.PathValue("username") + if username == "" { + apiErrorResponse(w, "missing username", http.StatusBadRequest, nil) + return + } + requester := ctx[ctxUser] + isSelf := username == requester + if !isSelf && !h.Users.CheckPermissions(requester, users.AdminLevel, users.NoEnvironment) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to refresh token for %s by %s", username, requester)) + return + } + if !h.Users.Exists(username) { + apiErrorResponse(w, "user not found", http.StatusNotFound, nil) + return + } + + token, expires, err := h.Users.CreateToken(username, h.ServiceName, tokenRefreshDefaultHours) + if err != nil { + apiErrorResponse(w, "error creating token", http.StatusInternalServerError, err) + return + } + if err := h.Users.UpdateToken(username, token, expires); err != nil { + apiErrorResponse(w, "error persisting token", http.StatusInternalServerError, err) + return + } + h.AuditLog.NewToken(username, strings.Split(r.RemoteAddr, ":")[0]) + log.Debug().Msgf("refreshed API token for %s (requested by %s)", username, requester) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.TokenResponse{Token: token, Expires: expires}) +} + +// DeleteUserTokenHandler - DELETE /api/v1/users/{username}/token +// +// Clears the user's APIToken so any existing JWT for them stops working. +// Requires super-admin OR the user themselves. 
+func (h *HandlersApi) DeleteUserTokenHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + username := r.PathValue("username") + if username == "" { + apiErrorResponse(w, "missing username", http.StatusBadRequest, nil) + return + } + requester := ctx[ctxUser] + isSelf := username == requester + if !isSelf && !h.Users.CheckPermissions(requester, users.AdminLevel, users.NoEnvironment) { + apiErrorResponse(w, "no access", http.StatusForbidden, fmt.Errorf("attempt to delete token for %s by %s", username, requester)) + return + } + if err := h.Users.ClearToken(username); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + apiErrorResponse(w, "user not found", http.StatusNotFound, err) + return + } + apiErrorResponse(w, "error clearing token", http.StatusInternalServerError, err) + return + } + h.AuditLog.UserAction(requester, "deleted token for "+username, strings.Split(r.RemoteAddr, ":")[0]) + log.Debug().Msgf("deleted API token for %s (requested by %s)", username, requester) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiGenericResponse{Message: "token deleted"}) +} + +// MeHandler - GET /api/v1/users/me +// +// Returns the currently authenticated user's profile (sans password hash +// and API token). Useful for the SPA's Profile page. 
+func (h *HandlersApi) MeHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + requester := ctx[ctxUser] + user, err := h.Users.Get(requester) + if err != nil { + apiErrorResponse(w, "error getting user", http.StatusInternalServerError, err) + return + } + resp := types.UserMeResponse{ + Username: user.Username, + Email: user.Email, + Fullname: user.Fullname, + Admin: user.Admin, + Service: user.Service, + UUID: user.UUID, + TokenExpire: user.TokenExpire, + LastAccess: user.LastAccess, + } + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, resp) +} + +// MePatchHandler - PATCH /api/v1/users/me +// +// Updates email and/or fullname for the currently authenticated user. +// Fields left empty in the body are not modified. Returns the updated profile. +func (h *HandlersApi) MePatchHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + requester := ctx[ctxUser] + var body types.UserMePatchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing PATCH body", http.StatusBadRequest, err) + return + } + body.Email = strings.TrimSpace(body.Email) + body.Fullname = strings.TrimSpace(body.Fullname) + + if body.Email != "" { + if err := h.Users.ChangeEmail(requester, body.Email); err != nil { + apiErrorResponse(w, "error updating email", http.StatusInternalServerError, err) + return + } + } + if body.Fullname != "" { + if err := h.Users.ChangeFullname(requester, body.Fullname); err != nil { + apiErrorResponse(w, "error updating fullname", http.StatusInternalServerError, err) + return + } + } + + user, err := h.Users.Get(requester) + if err != nil { + apiErrorResponse(w, "error fetching
updated user", http.StatusInternalServerError, err) + return + } + h.AuditLog.UserAction(requester, "updated own profile", strings.Split(r.RemoteAddr, ":")[0]) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.UserMeResponse{ + Username: user.Username, + Email: user.Email, + Fullname: user.Fullname, + Admin: user.Admin, + Service: user.Service, + UUID: user.UUID, + TokenExpire: user.TokenExpire, + LastAccess: user.LastAccess, + }) +} + +// MePasswordHandler - POST /api/v1/users/me/password +// +// Changes the currently authenticated user's password. Verifies the +// current password (bcrypt) before persisting the new hash. +func (h *HandlersApi) MePasswordHandler(w http.ResponseWriter, r *http.Request) { + if h.DebugHTTPConfig.EnableHTTP { + utils.DebugHTTPDump(h.DebugHTTP, r, h.DebugHTTPConfig.ShowBody) + } + ctx := r.Context().Value(ContextKey(contextAPI)).(ContextValue) + requester := ctx[ctxUser] + + var body types.PasswordChangeRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + apiErrorResponse(w, "error parsing POST body", http.StatusBadRequest, err) + return + } + if body.CurrentPassword == "" || body.NewPassword == "" { + apiErrorResponse(w, "current_password and new_password are required", http.StatusBadRequest, nil) + return + } + if len(body.NewPassword) < 8 { + apiErrorResponse(w, "new_password must be at least 8 characters", http.StatusBadRequest, nil) + return + } + if ok, _ := h.Users.CheckLoginCredentials(requester, body.CurrentPassword); !ok { + apiErrorResponse(w, "current password is incorrect", http.StatusForbidden, nil) + return + } + if err := h.Users.ChangePassword(requester, body.NewPassword); err != nil { + apiErrorResponse(w, "error changing password", http.StatusInternalServerError, err) + return + } + h.AuditLog.UserAction(requester, "changed own password", strings.Split(r.RemoteAddr, ":")[0]) + utils.HTTPResponse(w, utils.JSONApplicationUTF8, http.StatusOK, types.ApiGenericResponse{Message: 
"password changed"}) +} diff --git a/cmd/api/main.go b/cmd/api/main.go index 231f7e3e..43a46c10 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -20,10 +20,12 @@ import ( "github.com/jmpsec/osctrl/pkg/environments" "github.com/jmpsec/osctrl/pkg/logging" "github.com/jmpsec/osctrl/pkg/nodes" + "github.com/jmpsec/osctrl/pkg/osquery" "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/ratelimit" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/tags" + "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/users" "github.com/jmpsec/osctrl/pkg/utils" "github.com/jmpsec/osctrl/pkg/version" @@ -74,6 +76,8 @@ const ( apiNodesPath = "/nodes" // API queries path apiQueriesPath = "/queries" + // API saved queries path + apiSavedQueriesPath = "/saved-queries" // API users path apiUsersPath = "/users" // API all queries path @@ -90,6 +94,12 @@ const ( apiSettingsPath = "/settings" // API audit logs path apiAuditLogsPath = "/audit-logs" + // API logs path + apiLogsPath = "/logs" + // API stats path + apiStatsPath = "/stats" + // API osquery path + apiOsqueryPath = "/osquery" ) // Global variables @@ -109,8 +119,9 @@ var ( flags []cli.Flag serviceConfiguration config.APIConfiguration // FIXME this struct is temporary until we refactor to write settings to the DB - flagParams *config.ServiceParameters - auditLog *auditlog.AuditLogManager + flagParams *config.ServiceParameters + auditLog *auditlog.AuditLogManager + osqueryTables []types.OsqueryTable ) // Valid values for auth and logging in configuration @@ -291,6 +302,15 @@ func osctrlAPIService() { if err != nil { log.Fatal().Msgf("Error initializing audit log manager - %v", err) } + // Load osquery tables schema (best-effort; an empty slice is fine if the file doesn't exist) + if flagParams.Osquery.TablesFile != "" { + log.Info().Msgf("Loading osquery tables from %s", flagParams.Osquery.TablesFile) + osqueryTables, err = osquery.LoadTables(flagParams.Osquery.TablesFile) + if 
err != nil { + log.Warn().Msgf("Failed to load osquery tables: %v", err) + osqueryTables = []types.OsqueryTable{} + } + } // Initialize Admin handlers before router log.Info().Msg("Initializing handlers") handlersApi = handlers.CreateHandlersApi( @@ -307,6 +327,7 @@ func osctrlAPIService() { handlers.WithName(serviceName), handlers.WithAuditLog(auditLog), handlers.WithDebugHTTP(flagParams.Debug), + handlers.WithOsqueryTables(osqueryTables), handlers.WithOsqueryValues(*flagParams.Osquery), ) @@ -336,6 +357,18 @@ func osctrlAPIService() { handlersApi.AuditLog.FailedLogin("", utils.GetIP(r), "rate limit exceeded") }) muxAPI.Handle("POST "+_apiPath(apiLoginPath)+"/{env}", loginRateLimit(http.HandlerFunc(handlersApi.LoginHandler))) + // Pre-auth env list so the SPA login screen can offer a dropdown instead + // of a free-text field. The handler exposes only (uuid, name) — no + // secrets — and shares the same per-IP rate limiter as POST /login so the + // endpoint can't be turned into a higher-throughput env-enumeration probe. + muxAPI.Handle("GET "+_apiPath(apiLoginPath)+"/environments", loginRateLimit(http.HandlerFunc(handlersApi.LoginEnvironmentsHandler))) + // Pre-auth starter-sample endpoints. The SPA reads these to populate the + // queries/new and carves/new template rows. Samples are static read-only + // data shipped with the binary, not tenant- or env-scoped — same posture + // as /login/environments. Shared per-IP rate limiter blocks low-effort + // scanning probes. 
+ muxAPI.Handle("GET "+_apiPath(apiQueriesPath)+"/samples", loginRateLimit(http.HandlerFunc(handlersApi.QuerySamplesHandler))) + muxAPI.Handle("GET "+_apiPath(apiCarvesPath)+"/samples", loginRateLimit(http.HandlerFunc(handlersApi.CarveSamplesHandler))) // ///////////////////////// AUTHENTICATED // API: check auth muxAPI.Handle( @@ -362,6 +395,36 @@ func osctrlAPIService() { muxAPI.Handle( "POST "+_apiPath(apiNodesPath)+"/lookup", handlerAuthCheck(http.HandlerFunc(handlersApi.LookupNodeHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: paginated nodes — canonical SPA endpoint + muxAPI.Handle( + "GET "+_apiPath(apiNodesPath)+"/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.NodesPagedHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: node logs + muxAPI.Handle( + "GET "+_apiPath(apiLogsPath)+"/{type}/{env}/{uuid}", + handlerAuthCheck(http.HandlerFunc(handlersApi.NodeLogsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: cross-env dashboard stats + muxAPI.Handle( + "GET "+_apiPath(apiStatsPath), + handlerAuthCheck(http.HandlerFunc(handlersApi.StatsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: fleet-wide osquery version breakdown for dashboard's hygiene panel. + muxAPI.Handle( + "GET "+_apiPath(apiStatsPath)+"/osquery-versions", + handlerAuthCheck(http.HandlerFunc(handlersApi.OsqueryVersionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: per-env activity heatmap (15-min audit-log buckets across N hours). + muxAPI.Handle( + "GET "+_apiPath(apiStatsPath)+"/activity/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvActivityHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: per-node activity heatmap (status/result/query/carve buckets). 
+ muxAPI.Handle( + "GET "+_apiPath(apiStatsPath)+"/activity/node/{env}/{uuid}", + handlerAuthCheck(http.HandlerFunc(handlersApi.NodeActivityHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // Batch variant — accepts ?uuids=a,b,c (up to 100). Returns a map keyed by + // uuid. Lets the Nodes table render a per-row sparkline without firing N + // parallel HTTP requests. + muxAPI.Handle( + "GET "+_apiPath(apiStatsPath)+"/activity/node-batch/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.NodeActivityBatchHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) // API: queries by environment if flagParams.Osquery.Query { muxAPI.Handle( @@ -379,13 +442,34 @@ func osctrlAPIService() { muxAPI.Handle( "GET "+_apiPath(apiQueriesPath)+"/{env}/results/{name}", handlerAuthCheck(http.HandlerFunc(handlersApi.QueryResultsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // CSV export for query results + muxAPI.Handle( + "GET "+_apiPath(apiQueriesPath)+"/{env}/results/csv/{name}", + handlerAuthCheck(http.HandlerFunc(handlersApi.QueryResultsCSVHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "GET "+_apiPath(apiAllQueriesPath+"/{env}"), handlerAuthCheck(http.HandlerFunc(handlersApi.AllQueriesShowHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "POST "+_apiPath(apiQueriesPath)+"/{env}/{action}/{name}", handlerAuthCheck(http.HandlerFunc(handlersApi.QueriesActionHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: saved queries (Track 4) + muxAPI.Handle( + "GET "+_apiPath(apiSavedQueriesPath)+"/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.SavedQueriesListHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "POST "+_apiPath(apiSavedQueriesPath)+"/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.SavedQueryCreateHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "PATCH 
"+_apiPath(apiSavedQueriesPath)+"/{env}/{name}", + handlerAuthCheck(http.HandlerFunc(handlersApi.SavedQueryUpdateHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "DELETE "+_apiPath(apiSavedQueriesPath)+"/{env}/{name}", + handlerAuthCheck(http.HandlerFunc(handlersApi.SavedQueryDeleteHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) } + // API: osquery schema tables (globally available to authenticated users) + muxAPI.Handle( + "GET "+_apiPath(apiOsqueryPath)+"/tables", + handlerAuthCheck(http.HandlerFunc(handlersApi.OsqueryTablesHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) // API: carves by environment if flagParams.Osquery.Carve { muxAPI.Handle( @@ -403,17 +487,38 @@ func osctrlAPIService() { muxAPI.Handle( "GET "+_apiPath(apiCarvesPath)+"/{env}/{name}", handlerAuthCheck(http.HandlerFunc(handlersApi.CarveShowHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "GET "+_apiPath(apiCarvesPath)+"/{env}/archive/{name}", + handlerAuthCheck(http.HandlerFunc(handlersApi.CarveArchiveHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "POST "+_apiPath(apiCarvesPath)+"/{env}/{action}/{name}", handlerAuthCheck(http.HandlerFunc(handlersApi.CarvesActionHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) } // API: users + muxAPI.Handle( + "GET "+_apiPath(apiUsersPath)+"/me", + handlerAuthCheck(http.HandlerFunc(handlersApi.MeHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "PATCH "+_apiPath(apiUsersPath)+"/me", + handlerAuthCheck(http.HandlerFunc(handlersApi.MePatchHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "POST "+_apiPath(apiUsersPath)+"/me/password", + handlerAuthCheck(http.HandlerFunc(handlersApi.MePasswordHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "GET "+_apiPath(apiUsersPath)+"/{username}", handlerAuthCheck(http.HandlerFunc(handlersApi.UserHandler), 
flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "GET "+_apiPath(apiUsersPath), handlerAuthCheck(http.HandlerFunc(handlersApi.UsersHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "POST "+_apiPath(apiUsersPath)+"/{username}/permissions", + handlerAuthCheck(http.HandlerFunc(handlersApi.SetUserPermissionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "POST "+_apiPath(apiUsersPath)+"/{username}/token/refresh", + handlerAuthCheck(http.HandlerFunc(handlersApi.RefreshUserTokenHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "DELETE "+_apiPath(apiUsersPath)+"/{username}/token", + handlerAuthCheck(http.HandlerFunc(handlersApi.DeleteUserTokenHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( "POST "+_apiPath(apiUsersPath)+"/{username}/{action}", handlerAuthCheck(http.HandlerFunc(handlersApi.UserActionHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) @@ -426,7 +531,7 @@ func osctrlAPIService() { "GET "+_apiPath(apiEnvironmentsPath), handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( - "POST "+_apiPath(apiEnvironmentsPath), + "POST "+_apiPath(apiEnvironmentsPath)+"/actions", handlerAuthCheck(http.HandlerFunc(handlersApi.EnvActionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) muxAPI.Handle( @@ -447,6 +552,33 @@ func osctrlAPIService() { muxAPI.Handle( "POST "+_apiPath(apiEnvironmentsPath)+"/{env}/remove/{action}", handlerAuthCheck(http.HandlerFunc(handlersApi.EnvRemoveActionsHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: environments CRUD + config (Track 8) + muxAPI.Handle( + "POST "+_apiPath(apiEnvironmentsPath), + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentCreateHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "PATCH 
"+_apiPath(apiEnvironmentsPath)+"/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentUpdateHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "DELETE "+_apiPath(apiEnvironmentsPath)+"/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentDeleteHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // Env config routes use a `/config/{env}` shape (literal in segment 1) so + // they cannot register-conflict with `/map/{target}` registered above. A + // `/{env}/config` shape would put a wildcard in segment 1 — Go's ServeMux + // refuses to accept it alongside `/map/{target}` since neither pattern + // strictly dominates the other. + muxAPI.Handle( + "GET "+_apiPath(apiEnvironmentsPath)+"/config/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentConfigHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "PATCH "+_apiPath(apiEnvironmentsPath)+"/config/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentConfigPatchHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "PATCH "+_apiPath(apiEnvironmentsPath)+"/intervals/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentIntervalsPatchHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + muxAPI.Handle( + "PATCH "+_apiPath(apiEnvironmentsPath)+"/expiration/{env}", + handlerAuthCheck(http.HandlerFunc(handlersApi.EnvironmentExpirationPatchHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) // API: tags by environment muxAPI.Handle( "GET "+_apiPath(apiTagsPath), @@ -476,6 +608,10 @@ func osctrlAPIService() { muxAPI.Handle( "GET "+_apiPath(apiSettingsPath)+"/{service}/json/{env}", handlerAuthCheck(http.HandlerFunc(handlersApi.SettingsServiceEnvJSONHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) + // API: settings PATCH (Track 9) + muxAPI.Handle( + "PATCH "+_apiPath(apiSettingsPath)+"/{service}/{name}", + 
handlerAuthCheck(http.HandlerFunc(handlersApi.SettingPatchHandler), flagParams.Service.Auth, flagParams.JWT.JWTSecret)) // API: audit log if flagParams.Service.AuditLog { muxAPI.Handle( diff --git a/pkg/auditlog/audit.go b/pkg/auditlog/audit.go index bd1d3246..fdb0e39d 100644 --- a/pkg/auditlog/audit.go +++ b/pkg/auditlog/audit.go @@ -2,11 +2,113 @@ package auditlog import ( "fmt" + "time" "github.com/rs/zerolog/log" "gorm.io/gorm" ) +// LogTypes - allowlist of valid log_type filter values. Used by the +// paginated filter to reject arbitrary integers (defense in depth — the +// underlying column is uint so junk values just match nothing, but we +// surface a 400 to the SPA instead of an empty response). +var LogTypes = map[uint]struct{}{ + LogTypeLogin: {}, + LogTypeLogout: {}, + LogTypeNode: {}, + LogTypeQuery: {}, + LogTypeCarve: {}, + LogTypeTag: {}, + LogTypeEnvironment: {}, + LogTypeSetting: {}, + LogTypeVisit: {}, + LogTypeUser: {}, +} + +// PageFilter describes the inputs accepted by GetPaged. +// +// All string fields are case-insensitive partial matches except Service +// which is an exact match (services are a tiny fixed set: tls / admin / +// osctrl-api). EnvID == 0 means "no env filter" (NOT "the no-environment +// rows" — use a dedicated convention if that's ever needed). LogType == 0 +// means "no type filter". Since / Until are RFC3339 timestamps; either may +// be the zero value to mean unset. +type PageFilter struct { + Service string + Username string + LogType uint + EnvID uint + Since time.Time + Until time.Time + Page int + PageSize int +} + +// GetPaged returns audit logs filtered + paginated. Ordering is fixed at +// created_at DESC so the SPA always shows newest first. +// +// Returns (rows, totalItems, error). totalItems is computed with the same +// WHERE clause as the row query, at the cost of one extra COUNT round-trip.
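The Since/Until convention above (RFC3339 strings; zero value means unset) maps directly onto time.Parse. A hedged sketch of the handler-side parsing, assuming a helper name that is not part of this hunk:

```go
package main

import (
	"fmt"
	"time"
)

// parseTimeParam turns an optional RFC3339 query parameter into the
// zero-value-means-unset convention PageFilter expects. Illustrative only;
// the real handler plumbing is outside this hunk.
func parseTimeParam(v string) (time.Time, error) {
	if v == "" {
		return time.Time{}, nil // unset filter
	}
	return time.Parse(time.RFC3339, v)
}

func main() {
	since, err := parseTimeParam("2026-05-14T00:00:00Z")
	if err != nil {
		panic(err)
	}
	unset, _ := parseTimeParam("")
	fmt.Println(since.IsZero(), unset.IsZero()) // false true
}
```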
+func (m *AuditLogManager) GetPaged(f PageFilter) ([]AuditLog, int64, error) { + if f.PageSize <= 0 { + f.PageSize = 50 + } + if f.PageSize > 500 { + f.PageSize = 500 + } + if f.Page < 1 { + f.Page = 1 + } + + q := m.DB.Model(&AuditLog{}) + if f.Service != "" { + q = q.Where("service = ?", f.Service) + } + if f.Username != "" { + // case-insensitive partial match via LOWER(username) LIKE ... + q = q.Where("LOWER(username) LIKE ?", "%"+lowerLike(f.Username)+"%") + } + if f.LogType > 0 { + q = q.Where("log_type = ?", f.LogType) + } + if f.EnvID > 0 { + q = q.Where("environment_id = ?", f.EnvID) + } + if !f.Since.IsZero() { + q = q.Where("created_at >= ?", f.Since) + } + if !f.Until.IsZero() { + q = q.Where("created_at <= ?", f.Until) + } + + var total int64 + if err := q.Count(&total).Error; err != nil { + return nil, 0, fmt.Errorf("count AuditLog %w", err) + } + + var rows []AuditLog + offset := (f.Page - 1) * f.PageSize + if err := q.Order("created_at desc").Limit(f.PageSize).Offset(offset).Find(&rows).Error; err != nil { + return nil, 0, fmt.Errorf("paged AuditLog %w", err) + } + return rows, total, nil +} + +// lowerLike lowercases ASCII letters in a user-supplied search fragment for +// LIKE matching. It does not trim whitespace, escape the LIKE wildcards +// % and _, or accept regex; callers remain responsible for any further +// validation. +func lowerLike(s string) string { + out := make([]byte, 0, len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if c >= 'A' && c <= 'Z' { + c += 32 + } + out = append(out, c) + } + return string(out) +} + const ( // Log types LogTypeLogin = 1 @@ -176,6 +278,18 @@ func (m *AuditLogManager) NewCarve(username, path, ip string, envID uint) { } } +// SavedQueryAction - create new saved-query action audit log entry +// (create / update / delete operations on the saved_queries table).
+func (m *AuditLogManager) SavedQueryAction(username, action, ip string, envID uint) { + if !m.Enabled { + return + } + line := fmt.Sprintf("user %s performed saved-query action: %s", username, action) + if err := m.CreateNew(username, line, ip, LogTypeQuery, SeverityInfo, envID); err != nil { + log.Err(err).Msg("error creating saved-query audit log") + } +} + // QueryAction - create new query action audit log entry func (m *AuditLogManager) QueryAction(username, action, ip string, envID uint) { if !m.Enabled { @@ -331,6 +445,56 @@ func (m *AuditLogManager) GetByEnv(envID uint) ([]AuditLog, error) { return logs, nil } +// GetEnvSince — returns every audit row for the env since the given cutoff, +// selecting only id, log_type and created_at. Used by the activity heatmap +// so the dashboard can render a 24-hour fleet-activity strip without +// scanning the full audit_logs table. Fewer fields than GetByEnv keep the +// payload tiny — 24 hours of a busy env is still small enough to ship to +// the SPA, and trimming the SELECT to three columns keeps the SQL fast. +func (m *AuditLogManager) GetEnvSince(envID uint, since time.Time) ([]AuditLog, error) { + var logs []AuditLog + if err := m.DB. + Select("id, log_type, created_at"). + Where("environment_id = ? AND created_at >= ?", envID, since). + Order("created_at asc"). + Find(&logs).Error; err != nil { + return logs, fmt.Errorf("get AuditLog since %w", err) + } + return logs, nil +} + +// EnvActivityBucketRow is one (bucket_start, log_type, count) row returned +// from the bucketed env-activity query. +type EnvActivityBucketRow struct { + BucketStart int64 `gorm:"column:bucket_start"` + LogType uint `gorm:"column:log_type"` + Cnt int64 `gorm:"column:cnt"` +} + +// GetEnvActivityBucketed — returns audit-log counts grouped by bucket and +// log_type for one env, pushing the binning into SQL. Replaces the +// in-process histogram over GetEnvSince.
+func (m *AuditLogManager) GetEnvActivityBucketed(envID uint, since time.Time, bucketSeconds int) ([]EnvActivityBucketRow, error) {
+	// bucketExpr floors created_at to a bucket-aligned epoch; only the
+	// epoch-extraction function differs per dialect.
+	var bucketExpr string
+	switch m.DB.Dialector.Name() {
+	case "postgres":
+		bucketExpr = fmt.Sprintf("(floor(extract(epoch from created_at) / %d) * %d)::bigint", bucketSeconds, bucketSeconds)
+	case "mysql":
+		bucketExpr = fmt.Sprintf("(FLOOR(UNIX_TIMESTAMP(created_at) / %d) * %d)", bucketSeconds, bucketSeconds)
+	default:
+		bucketExpr = fmt.Sprintf("(CAST(strftime('%%s', created_at) AS INTEGER) / %d * %d)", bucketSeconds, bucketSeconds)
+	}
+	var rows []EnvActivityBucketRow
+	if err := m.DB.Model(&AuditLog{}).
+		Select(bucketExpr+" AS bucket_start, log_type, COUNT(*) AS cnt").
+		Where("environment_id = ? AND created_at >= ?", envID, since).
+		Group("bucket_start, log_type").
+		Scan(&rows).Error; err != nil {
+		return rows, fmt.Errorf("env-activity bucketed: %w", err)
+	}
+	return rows, nil
+}
+
 // GetByType - get audit logs by type and environment
 func (m *AuditLogManager) GetByTypeEnv(logType, envID uint) ([]AuditLog, error) {
 	var logs []AuditLog
diff --git a/pkg/carves/carves.go b/pkg/carves/carves.go
index 3e69c65e..114ecf34 100644
--- a/pkg/carves/carves.go
+++ b/pkg/carves/carves.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/jmpsec/osctrl/pkg/config"
+	"github.com/jmpsec/osctrl/pkg/dbutil"
 	"github.com/jmpsec/osctrl/pkg/types"
 	"github.com/rs/zerolog/log"
 	"gorm.io/gorm"
@@ -253,6 +254,31 @@ func (c *Carves) GetNodeCarves(uuid string) ([]CarvedFile, error) {
 	return carves, nil
 }
 
+// GetNodeCarveTimestamps returns CreatedAt of every CarvedFile row from this
+// node since the cutoff. Used by the per-node activity heatmap so it can
+// bucket without dragging the full carve metadata.
+func (c *Carves) GetNodeCarveTimestamps(uuid string, since time.Time) ([]time.Time, error) {
+	var ts []time.Time
+	err := c.DB.Model(&CarvedFile{}).
+		Where("uuid = ? AND created_at >= ?", uuid, since).
+		Pluck("created_at", &ts).Error
+	return ts, err
+}
+
+// GetNodeCarveBucketed returns per-bucket row counts for carved_files
+// rows produced by `uuid`. Same bucketing semantics as the logging-package
+// variants — see pkg/dbutil.BucketExpr.
+func (c *Carves) GetNodeCarveBucketed(uuid string, since time.Time, bucketSeconds int) ([]dbutil.BucketedRow, error) {
+	expr := dbutil.BucketExpr(c.DB, "created_at", bucketSeconds)
+	var rows []dbutil.BucketedRow
+	err := c.DB.Model(&CarvedFile{}).
+		Select(expr+" AS bucket_start, COUNT(*) AS cnt").
+		Where("uuid = ? AND created_at >= ?", uuid, since).
+		Group("bucket_start").
+		Scan(&rows).Error
+	return rows, err
+}
+
 // ChangeStatus to change the status of a carve
 func (c *Carves) ChangeStatus(status, sessionid string) error {
 	carve, err := c.GetBySession(sessionid)
diff --git a/pkg/carves/samples.go b/pkg/carves/samples.go
new file mode 100644
index 00000000..6bd58d7f
--- /dev/null
+++ b/pkg/carves/samples.go
@@ -0,0 +1,236 @@
+package carves
+
+// Starter file-carve target samples shipped with osctrl. Used by:
+//   - GET /api/v1/carves/samples — SPA carves/new form populates its
+//     path-templates row from this list so new operators have ready-made
+//     forensic targets to start from.
+//
+// Unlike query samples, carves are not seeded into a persistent library.
+// A carve is an incident-response action against a specific path on
+// specific nodes; operators run them ad-hoc, not on a schedule. The
+// samples below are the "what would I grab first?" common targets.
+//
+// Coverage spans linux, darwin, and windows so every platform has at least
+// 6 starting templates regardless of which OS the operator's looking at.
+
+// CarveSampleCategory groups paths so the SPA can label them for the
+// operator (Auth / Logs / Registry / etc). Closed set; new categories
+// require updating the SPA's label map too.
+type CarveSampleCategory string
+
+const (
+	CarveCategoryAuth     CarveSampleCategory = "auth"
+	CarveCategoryLogs     CarveSampleCategory = "logs"
+	CarveCategoryRegistry CarveSampleCategory = "registry"
+	CarveCategoryKeychain CarveSampleCategory = "keychain"
+	CarveCategoryHistory  CarveSampleCategory = "history"
+	CarveCategoryConfig   CarveSampleCategory = "config"
+)
+
+// CarveSamplePlatform — aligns with the platform buckets used elsewhere in
+// osctrl. Each sample is single-platform because file paths are
+// platform-specific by definition.
+type CarveSamplePlatform string
+
+const (
+	CarvePlatformLinux   CarveSamplePlatform = "linux"
+	CarvePlatformDarwin  CarveSamplePlatform = "darwin"
+	CarvePlatformWindows CarveSamplePlatform = "windows"
+)
+
+// CarveSample is one starter target row.
+type CarveSample struct {
+	Label    string              `json:"label"`
+	Path     string              `json:"path"`
+	Platform CarveSamplePlatform `json:"platform"`
+	Category CarveSampleCategory `json:"category"`
+	// Notes is a brief operator-facing description of why this file is
+	// worth grabbing during an investigation. Surfaced as a tooltip in
+	// the SPA template row.
+	Notes string `json:"notes"`
+}
+
+// CarveSamples is the canonical starter library. ~24 entries across the
+// three major platforms. Ordering is by platform then category so the SPA's
+// template row reads in a predictable shape.
+var CarveSamples = []CarveSample{
+	// ── Linux — auth ───────────────────────────────────────────────────────
+	{
+		Label:    "/etc/passwd",
+		Path:     "/etc/passwd",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryAuth,
+		Notes:    "Local user account database (read by every getpwnam call).",
+	},
+	{
+		Label:    "/etc/shadow",
+		Path:     "/etc/shadow",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryAuth,
+		Notes:    "Hashed password store — root-readable only; presence in carve output confirms agent ran as root.",
+	},
+	{
+		Label:    "/etc/sudoers",
+		Path:     "/etc/sudoers",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryAuth,
+		Notes:    "Sudo privilege configuration. Compare across hosts to spot drift.",
+	},
+	// ── Linux — logs ───────────────────────────────────────────────────────
+	{
+		Label:    "/var/log/auth.log",
+		Path:     "/var/log/auth.log",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryLogs,
+		Notes:    "SSH / sudo / PAM authentication events (Debian / Ubuntu).",
+	},
+	{
+		Label:    "/var/log/secure",
+		Path:     "/var/log/secure",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryLogs,
+		Notes:    "SSH / sudo / PAM authentication events (RHEL / CentOS / Fedora).",
+	},
+	{
+		Label:    "/var/log/syslog",
+		Path:     "/var/log/syslog",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryLogs,
+		Notes:    "General system messages; correlate with auth.log for a fuller timeline.",
+	},
+	// ── Linux — history / config ───────────────────────────────────────────
+	{
+		Label:    "/root/.bash_history",
+		Path:     "/root/.bash_history",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryHistory,
+		Notes:    "Root shell command history — first thing to grab on suspected compromise.",
+	},
+	{
+		Label:    "/etc/crontab",
+		Path:     "/etc/crontab",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryConfig,
+		Notes:    "System-wide cron schedule. Check for unfamiliar entries.",
+	},
+	{
+		Label:    "/etc/hosts",
+		Path:     "/etc/hosts",
+		Platform: CarvePlatformLinux,
+		Category: CarveCategoryConfig,
+		Notes:    "Local hostname overrides. Tampered entries can redirect traffic.",
+	},
+
+	// ── macOS — auth ───────────────────────────────────────────────────────
+	{
+		Label:    "/etc/passwd",
+		Path:     "/etc/passwd",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryAuth,
+		Notes:    "Local user account database (legacy; macOS primarily uses OpenDirectory).",
+	},
+	{
+		Label:    "/var/db/dslocal/nodes/Default/users",
+		Path:     "/var/db/dslocal/nodes/Default/users",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryAuth,
+		Notes:    "Local user records in OpenDirectory (plist files; carve the directory).",
+	},
+	// ── macOS — keychain / logs ────────────────────────────────────────────
+	{
+		Label:    "~/Library/Keychains",
+		Path:     "/Users",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryKeychain,
+		Notes:    "User keychain directories. Carve a specific user's path: /Users//Library/Keychains.",
+	},
+	{
+		Label:    "/var/log/system.log",
+		Path:     "/var/log/system.log",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryLogs,
+		Notes:    "Pre-unified-logging system messages.",
+	},
+	{
+		Label:    "/var/log/install.log",
+		Path:     "/var/log/install.log",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryLogs,
+		Notes:    "Software install / update events — useful for spotting unexpected pkg installs.",
+	},
+	// ── macOS — history / config ───────────────────────────────────────────
+	{
+		Label:    "~/.zsh_history (root)",
+		Path:     "/var/root/.zsh_history",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryHistory,
+		Notes:    "Root zsh history. Adjust path for non-root users: /Users//.zsh_history.",
+	},
+	{
+		Label:    "/etc/hosts",
+		Path:     "/etc/hosts",
+		Platform: CarvePlatformDarwin,
+		Category: CarveCategoryConfig,
+		Notes:    "Local hostname overrides.",
+	},
+
+	// ── Windows — auth (registry hives) ────────────────────────────────────
+	{
+		Label:    `SAM hive`,
+		Path:     `C:\Windows\System32\config\SAM`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryRegistry,
+		Notes:    "Local account database hive. File is locked while Windows runs; carve from VSS shadow or live-running osquery as SYSTEM.",
+	},
+	{
+		Label:    `SYSTEM hive`,
+		Path:     `C:\Windows\System32\config\SYSTEM`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryRegistry,
+		Notes:    "System configuration hive. Contains services, drivers, BootKey for SAM decryption.",
+	},
+	{
+		Label:    `SECURITY hive`,
+		Path:     `C:\Windows\System32\config\SECURITY`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryRegistry,
+		Notes:    "Local security policy hive. Contains LSA secrets and cached domain credentials.",
+	},
+	// ── Windows — logs ─────────────────────────────────────────────────────
+	{
+		Label:    `Security event log`,
+		Path:     `C:\Windows\System32\winevt\Logs\Security.evtx`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryLogs,
+		Notes:    "Windows security audit log — logon events, privilege use, object access.",
+	},
+	{
+		Label:    `System event log`,
+		Path:     `C:\Windows\System32\winevt\Logs\System.evtx`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryLogs,
+		Notes:    "System events — services, drivers, hardware. Pairs with Security.evtx for correlation.",
+	},
+	{
+		Label:    `PowerShell op log`,
+		Path:     `C:\Windows\System32\winevt\Logs\Microsoft-Windows-PowerShell%4Operational.evtx`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryLogs,
+		Notes:    "PowerShell script-block and pipeline execution log. High-value for attacker activity.",
+	},
+	// ── Windows — config ───────────────────────────────────────────────────
+	{
+		Label:    `hosts file`,
+		Path:     `C:\Windows\System32\drivers\etc\hosts`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryConfig,
+		Notes:    "Local hostname overrides. Should rarely change in a managed fleet.",
+	},
+	{
+		Label:    `NTUSER.DAT (per-user)`,
+		Path:     `C:\Users`,
+		Platform: CarvePlatformWindows,
+		Category: CarveCategoryConfig,
+		Notes:    "Per-user registry hive. Carve a specific user: C:\\Users\\\\NTUSER.DAT (locked while user is logged in).",
+	},
+}
diff --git a/pkg/dbutil/buckets.go b/pkg/dbutil/buckets.go
new file mode 100644
index 00000000..ebf89701
--- /dev/null
+++ b/pkg/dbutil/buckets.go
@@ -0,0 +1,78 @@
+package dbutil
+
+import (
+	"fmt"
+	"time"
+
+	"gorm.io/gorm"
+)
+
+// BucketExpr returns the SQL expression that floors the given timestamp
+// `column` to a bucket-aligned unix timestamp. Same shape on every dialect —
+// only the epoch-extraction function differs.
+//
+// The expression returns an integer number of seconds since the epoch,
+// truncated down to the nearest `bucketSeconds` boundary. Group by this
+// expression, count(*), and you have a contiguous-bucket histogram.
+func BucketExpr(db *gorm.DB, column string, bucketSeconds int) string {
+	switch db.Dialector.Name() {
+	case "postgres":
+		return fmt.Sprintf(
+			"(floor(extract(epoch from %s) / %d) * %d)::bigint",
+			column, bucketSeconds, bucketSeconds,
+		)
+	case "mysql":
+		return fmt.Sprintf(
+			"(FLOOR(UNIX_TIMESTAMP(%s) / %d) * %d)",
+			column, bucketSeconds, bucketSeconds,
+		)
+	case "sqlite":
+		return fmt.Sprintf(
+			"(CAST(strftime('%%s', %s) AS INTEGER) / %d * %d)",
+			column, bucketSeconds, bucketSeconds,
+		)
+	default:
+		// Best-effort fallback that reuses the SQLite expression; dialects
+		// without strftime will reject it at query time, but the three
+		// supported dialects above are covered.
+		return fmt.Sprintf(
+			"(CAST(strftime('%%s', %s) AS INTEGER) / %d * %d)",
+			column, bucketSeconds, bucketSeconds,
+		)
+	}
+}
+
+// BucketCount represents one row of a bucketed count query.
+type BucketCount struct {
+	Bucket int64 // Unix seconds at the start of the bucket
+	Count  int64
+}
+
+// BucketedRow is the raw scan target for the GROUP BY query. Stays
+// dialect-agnostic since every dialect returns BIGINT for FLOOR/CAST
+// expressions.
+type BucketedRow struct {
+	BucketStart int64 `gorm:"column:bucket_start"`
+	Cnt         int64 `gorm:"column:cnt"`
+}
+
+// DensifyBuckets takes a sparse list of {bucketStart, count} rows from the
+// DB and emits a dense `nBuckets`-long slice aligned to `startUnix`. Bucket
+// indexes outside the range are dropped — they can't render in a heatmap
+// of fixed width.
+func DensifyBuckets(rows []BucketedRow, startUnix int64, bucketSeconds int, nBuckets int) []int64 {
+	out := make([]int64, nBuckets)
+	for _, r := range rows {
+		idx := int((r.BucketStart - startUnix) / int64(bucketSeconds))
+		if idx < 0 || idx >= nBuckets {
+			continue
+		}
+		out[idx] = r.Cnt
+	}
+	return out
+}
+
+// AlignBucketStart rounds `t` down to the nearest `bucketSeconds` boundary.
+// Used so the API and the rollup-writer agree on bucket edges to the second.
+func AlignBucketStart(t time.Time, bucketSeconds int) time.Time {
+	return time.Unix((t.UTC().Unix()/int64(bucketSeconds))*int64(bucketSeconds), 0).UTC()
+}
diff --git a/pkg/environments/environments.go b/pkg/environments/environments.go
index 848cece5..0b941ac4 100644
--- a/pkg/environments/environments.go
+++ b/pkg/environments/environments.go
@@ -53,46 +53,49 @@ const (
 
 // TLSEnvironment to hold each of the TLS environment
 type TLSEnvironment struct {
-	gorm.Model
-	UUID             string `gorm:"index"`
-	Name             string
-	Hostname         string
-	Secret           string
-	EnrollSecretPath string
-	EnrollExpire     time.Time
-	RemoveSecretPath string
-	RemoveExpire     time.Time
-	Type             string
-	DebPackage       string
-	RpmPackage       string
-	MsiPackage       string
-	PkgPackage       string
-	DebugHTTP        bool
-	Icon             string
-	Options          string
-	Schedule         string
-	Packs            string
-	Decorators       string
-	ATC              string
-	Configuration    string
-	Flags            string
-	Certificate      string
-	ConfigTLS        bool
-	ConfigInterval   int
-	LoggingTLS       bool
-	LogInterval      int
-	QueryTLS         bool
-	QueryInterval    int
-	CarvesTLS        bool
-	EnrollPath       string
-	LogPath          string
-	ConfigPath       string
-	QueryReadPath    string
-	QueryWritePath   string
-	CarverInitPath   string
-	CarverBlockPath  string
-	AcceptEnrolls    bool
-	UserID           uint
+	ID               uint           `gorm:"primarykey" json:"id"`
+	CreatedAt        time.Time      `json:"created_at"`
+	UpdatedAt        time.Time      `json:"updated_at"`
+	DeletedAt        gorm.DeletedAt `gorm:"index" json:"-"`
+	UUID             string         `gorm:"index" json:"uuid"`
+	Name             string         `json:"name"`
+	Hostname         string         `json:"hostname"`
+	Secret           string         `json:"secret"`
+	EnrollSecretPath string         `json:"enroll_secret_path"`
+	EnrollExpire     time.Time      `json:"enroll_expire"`
+	RemoveSecretPath string         `json:"remove_secret_path"`
+	RemoveExpire     time.Time      `json:"remove_expire"`
+	Type             string         `json:"type"`
+	DebPackage       string         `json:"deb_package"`
+	RpmPackage       string         `json:"rpm_package"`
+	MsiPackage       string         `json:"msi_package"`
+	PkgPackage       string         `json:"pkg_package"`
+	DebugHTTP        bool           `json:"debug_http"`
+	Icon             string         `json:"icon"`
+	Options          string         `json:"options"`
+	Schedule         string         `json:"schedule"`
+	Packs            string         `json:"packs"`
+	Decorators       string         `json:"decorators"`
+	ATC              string         `json:"atc"`
+	Configuration    string         `json:"configuration"`
+	Flags            string         `json:"flags"`
+	Certificate      string         `json:"certificate"`
+	ConfigTLS        bool           `json:"config_tls"`
+	ConfigInterval   int            `json:"config_interval"`
+	LoggingTLS       bool           `json:"logging_tls"`
+	LogInterval      int            `json:"log_interval"`
+	QueryTLS         bool           `json:"query_tls"`
+	QueryInterval    int            `json:"query_interval"`
+	CarvesTLS        bool           `json:"carves_tls"`
+	EnrollPath       string         `json:"enroll_path"`
+	LogPath          string         `json:"log_path"`
+	ConfigPath       string         `json:"config_path"`
+	QueryReadPath    string         `json:"query_read_path"`
+	QueryWritePath   string         `json:"query_write_path"`
+	CarverInitPath   string         `json:"carver_init_path"`
+	CarverBlockPath  string         `json:"carver_block_path"`
+	AcceptEnrolls    bool           `json:"accept_enrolls"`
+	UserID           uint           `json:"user_id"`
 }
 
 // MapEnvironments to hold the TLS environments by name and UUID
diff --git a/pkg/logging/db.go b/pkg/logging/db.go
index 9c4e94cc..491268df 100644
--- a/pkg/logging/db.go
+++ b/pkg/logging/db.go
@@ -11,6 +11,7 @@ import (
 	"github.com/jmpsec/osctrl/pkg/backend"
 	"github.com/jmpsec/osctrl/pkg/config"
+	"github.com/jmpsec/osctrl/pkg/dbutil"
 	"github.com/jmpsec/osctrl/pkg/queries"
 	"github.com/jmpsec/osctrl/pkg/settings"
 	"github.com/jmpsec/osctrl/pkg/types"
@@ -217,6 +218,233 @@ func (logDB *LoggerDB) ResultLogsLimit(uuid, environment string, limit int) ([]O
 	return logs, nil
 }
 
+// GetNodeLogs retrieves recent log entries for a single node (status or result).
+// logType must be "status" or "result". Results are ordered by created_at DESC.
+// If since is non-zero only entries created strictly after that time are returned.
+// limit is clamped to [1, 1000].
+//
+// search is an optional free-text filter (substring, case-insensitive). It
+// runs as a `LIKE` against the human-readable text columns of the row:
+//   - status: line + message + filename
+//   - result: name + action + columns (the serialized JSON of matched fields)
+//
+// Empty search disables the filter — same behavior as a missing param.
+//
+// The `LIKE` is unindexed today. If the result_data / status_data tables
+// grow large enough to make this slow, an operator-side workaround is to
+// narrow `since` first, which keeps the matched row count small.
+func GetNodeLogs(db *gorm.DB, logType, env, uuid string, since time.Time, limit int, search string) ([]map[string]any, error) {
+	if limit <= 0 {
+		limit = 100
+	}
+	if limit > 1000 {
+		limit = 1000
+	}
+	uuid = strings.ToUpper(uuid)
+	// Escape SQL LIKE wildcards in the user input so a literal '%' in a
+	// pasted token doesn't match more than intended. Parameter binding keeps
+	// the input out of the SQL text, but LIKE still interprets %, _ and \
+	// inside the bound value, so those must be escaped by hand.
+	likeNeedle := ""
+	if search != "" {
+		needle := strings.ReplaceAll(search, `\`, `\\`)
+		needle = strings.ReplaceAll(needle, `%`, `\%`)
+		needle = strings.ReplaceAll(needle, `_`, `\_`)
+		likeNeedle = "%" + needle + "%"
+	}
+
+	var result []map[string]any
+
+	switch logType {
+	case types.StatusLog:
+		var rows []OsqueryStatusData
+		q := db.Where("uuid = ? AND environment = ?", uuid, env)
+		if !since.IsZero() {
+			q = q.Where("created_at > ?", since)
+		}
+		if likeNeedle != "" {
+			// LOWER() so the search is case-insensitive. The needle is
+			// already plain-text; lowercasing both sides handles UTF-8
+			// only weakly (no Unicode case-folding) but is good enough
+			// for the IR/incident use case which is mostly ASCII tokens.
+			lowerNeedle := strings.ToLower(likeNeedle)
+			q = q.Where(
+				"LOWER(line) LIKE ? OR LOWER(message) LIKE ? OR LOWER(filename) LIKE ?",
+				lowerNeedle, lowerNeedle, lowerNeedle,
+			)
+		}
+		if err := q.Order("created_at DESC").Limit(limit).Find(&rows).Error; err != nil {
+			return nil, err
+		}
+		for _, r := range rows {
+			result = append(result, map[string]any{
+				"id":          r.ID,
+				"created_at":  r.CreatedAt,
+				"uuid":        r.UUID,
+				"environment": r.Environment,
+				"line":        r.Line,
+				"message":     r.Message,
+				"version":     r.Version,
+				"filename":    r.Filename,
+				"severity":    r.Severity,
+			})
+		}
+	case types.ResultLog:
+		var rows []OsqueryResultData
+		q := db.Where("uuid = ? AND environment = ?", uuid, env)
+		if !since.IsZero() {
+			q = q.Where("created_at > ?", since)
+		}
+		if likeNeedle != "" {
+			lowerNeedle := strings.ToLower(likeNeedle)
+			q = q.Where(
+				"LOWER(name) LIKE ? OR LOWER(action) LIKE ? OR LOWER(columns) LIKE ?",
+				lowerNeedle, lowerNeedle, lowerNeedle,
+			)
+		}
+		if err := q.Order("created_at DESC").Limit(limit).Find(&rows).Error; err != nil {
+			return nil, err
+		}
+		for _, r := range rows {
+			result = append(result, map[string]any{
+				"id":          r.ID,
+				"created_at":  r.CreatedAt,
+				"uuid":        r.UUID,
+				"environment": r.Environment,
+				"name":        r.Name,
+				"action":      r.Action,
+				"epoch":       r.Epoch,
+				"columns":     r.Columns,
+				"counter":     r.Counter,
+			})
+		}
+	default:
+		return nil, fmt.Errorf("invalid log type: %s", logType)
+	}
+
+	return result, nil
+}
+
+// GetNodeStatusTimestamps and GetNodeResultTimestamps return just the
+// CreatedAt column for every status/result log row a given node has shipped
+// since `since`. Used by the per-node activity heatmap so it can bucket on
+// the API side without dragging the row bodies across the wire.
+//
+// Returning a slice of timestamps (rather than int64 epochs) keeps the
+// downstream bucketing arithmetic in Go's time domain, which is what the
+// rest of cmd/api/handlers/stats.go uses.
+func GetNodeStatusTimestamps(db *gorm.DB, env, uuid string, since time.Time) ([]time.Time, error) {
+	uuid = strings.ToUpper(uuid)
+	var ts []time.Time
+	err := db.Model(&OsqueryStatusData{}).
+		Where("uuid = ? AND environment = ? AND created_at >= ?", uuid, env, since).
+		Pluck("created_at", &ts).Error
+	return ts, err
+}
+
+func GetNodeResultTimestamps(db *gorm.DB, env, uuid string, since time.Time) ([]time.Time, error) {
+	uuid = strings.ToUpper(uuid)
+	var ts []time.Time
+	err := db.Model(&OsqueryResultData{}).
+		Where("uuid = ? AND environment = ? AND created_at >= ?", uuid, env, since).
+		Pluck("created_at", &ts).Error
+	return ts, err
+}
+
+// GetNodeStatusBucketed returns per-bucket row counts for `uuid` in `env`
+// since `since`, with buckets aligned to `bucketSeconds`. The SQL pushes the
+// histogram into the database (one GROUP BY) instead of shipping every
+// timestamp to the API process — orders of magnitude less wire traffic on
+// chatty nodes.
+func GetNodeStatusBucketed(db *gorm.DB, env, uuid string, since time.Time, bucketSeconds int) ([]dbutil.BucketedRow, error) {
+	uuid = strings.ToUpper(uuid)
+	expr := dbutil.BucketExpr(db, "created_at", bucketSeconds)
+	var rows []dbutil.BucketedRow
+	err := db.Model(&OsqueryStatusData{}).
+		Select(expr+" AS bucket_start, COUNT(*) AS cnt").
+		Where("uuid = ? AND environment = ? AND created_at >= ?", uuid, env, since).
+		Group("bucket_start").
+		Scan(&rows).Error
+	return rows, err
+}
+
+// GetNodeResultBucketed mirrors GetNodeStatusBucketed for osquery_result_data.
+func GetNodeResultBucketed(db *gorm.DB, env, uuid string, since time.Time, bucketSeconds int) ([]dbutil.BucketedRow, error) {
+	uuid = strings.ToUpper(uuid)
+	expr := dbutil.BucketExpr(db, "created_at", bucketSeconds)
+	var rows []dbutil.BucketedRow
+	err := db.Model(&OsqueryResultData{}).
+		Select(expr+" AS bucket_start, COUNT(*) AS cnt").
+		Where("uuid = ? AND environment = ? AND created_at >= ?", uuid, env, since).
+		Group("bucket_start").
+		Scan(&rows).Error
+	return rows, err
+}
+
+// GetQueryResults retrieves rows of query result data (one per node) for a single query name.
+// Results are ordered by created_at ASC (oldest first — query results are append-only).
+// If since is non-zero only rows created strictly after that time are returned.
+// page is 1-indexed; pageSize is clamped to [1, 1000]; pageSize <= 0 defaults to 100.
+// Returns the page items, total matching rows, and any error.
+func GetQueryResults(db *gorm.DB, name string, since time.Time, page, pageSize int) ([]map[string]any, int64, error) {
+	if pageSize <= 0 {
+		pageSize = 100
+	}
+	if pageSize > 1000 {
+		pageSize = 1000
+	}
+	if page <= 0 {
+		page = 1
+	}
+	offset := (page - 1) * pageSize
+
+	q := db.Model(&OsqueryQueryData{}).Where("name = ?", name)
+	if !since.IsZero() {
+		q = q.Where("created_at > ?", since)
+	}
+	var total int64
+	if err := q.Count(&total).Error; err != nil {
+		return nil, 0, err
+	}
+	var rows []OsqueryQueryData
+	if err := q.Order("created_at ASC").Offset(offset).Limit(pageSize).Find(&rows).Error; err != nil {
+		return nil, 0, err
+	}
+	out := make([]map[string]any, 0, len(rows))
+	for _, r := range rows {
+		out = append(out, map[string]any{
+			"id":          r.ID,
+			"created_at":  r.CreatedAt,
+			"uuid":        r.UUID,
+			"environment": r.Environment,
+			"name":        r.Name,
+			"data":        r.Data,
+			"status":      r.Status,
+		})
+	}
+	return out, total, nil
+}
+
+// StreamQueryResults invokes fn for each row of query result data for `name`, ordered by created_at ASC.
+// Rows are read via a cursor so memory usage stays bounded — used by the CSV exporter.
+// fn may return an error to stop iteration; that error is returned by StreamQueryResults.
+func StreamQueryResults(db *gorm.DB, name string, fn func(OsqueryQueryData) error) error {
+	rows, err := db.Model(&OsqueryQueryData{}).Where("name = ?", name).Order("created_at ASC").Rows()
+	if err != nil {
+		return err
+	}
+	defer rows.Close()
+	for rows.Next() {
+		var r OsqueryQueryData
+		if err := db.ScanRows(rows, &r); err != nil {
+			return err
+		}
+		if err := fn(r); err != nil {
+			return err
+		}
+	}
+	return rows.Err()
+}
+
 // CleanStatusLogs will delete old status logs
 func (logDB *LoggerDB) CleanStatusLogs(environment string, seconds int64) error {
 	minusSeconds := time.Now().Add(time.Duration(-seconds) * time.Second)
diff --git a/pkg/nodes/models.go b/pkg/nodes/models.go
index e1192fad..cb09fbcc 100644
--- a/pkg/nodes/models.go
+++ b/pkg/nodes/models.go
@@ -8,57 +8,63 @@ import (
 
 // OsqueryNode as abstraction of a node
 type OsqueryNode struct {
-	gorm.Model
-	NodeKey         string `gorm:"index"`
-	UUID            string `gorm:"index"`
-	Platform        string
-	PlatformVersion string
-	OsqueryVersion  string
-	Hostname        string
-	Localname       string
-	IPAddress       string
-	Username        string
-	OsqueryUser     string
-	Environment     string
-	CPU             string
-	Memory          string
-	HardwareSerial  string
-	DaemonHash      string
-	ConfigHash      string
-	BytesReceived   int
-	RawEnrollment   string
-	LastSeen        time.Time
-	UserID          uint
-	EnvironmentID   uint
-	ExtraData       string
+	ID              uint           `gorm:"primarykey" json:"id"`
+	CreatedAt       time.Time      `json:"created_at"`
+	UpdatedAt       time.Time      `json:"updated_at"`
+	DeletedAt       gorm.DeletedAt `gorm:"index" json:"-"`
+	NodeKey         string         `gorm:"index" json:"-"`
+	UUID            string         `gorm:"index" json:"uuid"`
+	Platform        string         `json:"platform"`
+	PlatformVersion string         `json:"platform_version"`
+	OsqueryVersion  string         `json:"osquery_version"`
+	Hostname        string         `json:"hostname"`
+	Localname       string         `json:"localname"`
+	IPAddress       string         `json:"ip_address"`
+	Username        string         `json:"username"`
+	OsqueryUser     string         `json:"osquery_user"`
+	Environment     string         `json:"environment"`
+	CPU             string         `json:"cpu"`
+	Memory          string         `json:"memory"`
+	HardwareSerial  string         `json:"hardware_serial"`
+	DaemonHash      string         `json:"daemon_hash"`
+	ConfigHash      string         `json:"config_hash"`
+	BytesReceived   int            `json:"bytes_received"`
+	RawEnrollment   string         `json:"-"`
+	LastSeen        time.Time      `json:"last_seen"`
+	UserID          uint           `json:"user_id"`
+	EnvironmentID   uint           `json:"environment_id"`
+	ExtraData       string         `json:"extra_data"`
 }
 
 // ArchiveOsqueryNode as abstraction of an archived node
 type ArchiveOsqueryNode struct {
-	gorm.Model
-	NodeKey         string `gorm:"index"`
-	UUID            string `gorm:"index"`
-	Trigger         string
-	Platform        string
-	PlatformVersion string
-	OsqueryVersion  string
-	Hostname        string
-	Localname       string
-	IPAddress       string
-	Username        string
-	OsqueryUser     string
-	Environment     string
-	CPU             string
-	Memory          string
-	HardwareSerial  string
-	ConfigHash      string
-	DaemonHash      string
-	BytesReceived   int
-	RawEnrollment   string
-	LastSeen        time.Time
-	UserID          uint
-	EnvironmentID   uint
-	ExtraData       string
+	ID              uint           `gorm:"primarykey" json:"id"`
+	CreatedAt       time.Time      `json:"created_at"`
+	UpdatedAt       time.Time      `json:"updated_at"`
+	DeletedAt       gorm.DeletedAt `gorm:"index" json:"-"`
+	NodeKey         string         `gorm:"index" json:"-"`
+	UUID            string         `gorm:"index" json:"uuid"`
+	Trigger         string         `json:"trigger"`
+	Platform        string         `json:"platform"`
+	PlatformVersion string         `json:"platform_version"`
+	OsqueryVersion  string         `json:"osquery_version"`
+	Hostname        string         `json:"hostname"`
+	Localname       string         `json:"localname"`
+	IPAddress       string         `json:"ip_address"`
+	Username        string         `json:"username"`
+	OsqueryUser     string         `json:"osquery_user"`
+	Environment     string         `json:"environment"`
+	CPU             string         `json:"cpu"`
+	Memory          string         `json:"memory"`
+	HardwareSerial  string         `json:"hardware_serial"`
+	ConfigHash      string         `json:"config_hash"`
+	DaemonHash      string         `json:"daemon_hash"`
+	BytesReceived   int            `json:"bytes_received"`
+	RawEnrollment   string         `json:"-"`
+	LastSeen        time.Time      `json:"last_seen"`
+	UserID          uint           `json:"user_id"`
+	EnvironmentID   uint           `json:"environment_id"`
+	ExtraData       string         `json:"extra_data"`
 }
 
 // NodeMetadata to hold metadata for a node
diff --git a/pkg/nodes/nodes.go b/pkg/nodes/nodes.go
index 75e11de9..e8fce088 100644
--- a/pkg/nodes/nodes.go
+++ b/pkg/nodes/nodes.go
@@ -198,35 +198,6 @@ func (n *NodeManager) GetByEnv(env, target string, hours int64) ([]OsqueryNode,
 	return nodes, nil
 }
 
-// GetByEnvPage retrieves a page of nodes by environment applying target filters using LIMIT/OFFSET
-func (n *NodeManager) GetByEnvPage(env, target string, hours int64, offset, limit int, orderBy string, desc bool) ([]OsqueryNode, error) {
-	var nodes []OsqueryNode
-	if limit <= 0 { // safety default
-		limit = 25
-	}
-	if limit > 500 { // cap to avoid abuse
-		limit = 500
-	}
-	if offset < 0 {
-		offset = 0
-	}
-	query := n.DB.Where("environment = ?", env)
-	query = ApplyNodeTarget(query, target, hours)
-	// Default ordering only if client did not request a specific column
-	orderExpr := "last_seen DESC"
-	if orderBy != "" {
-		direction := "ASC"
-		if desc {
-			direction = "DESC"
-		}
-		orderExpr = orderBy + " " + direction
-	}
-	if err := query.Order(orderExpr).Offset(offset).Limit(limit).Find(&nodes).Error; err != nil {
-		return nodes, err
-	}
-	return nodes, nil
-}
-
 // CountByEnvTarget counts nodes for an environment after applying target (active/inactive/all)
 func (n *NodeManager) CountByEnvTarget(env string, target string, hours int64) (int64, error) {
 	var count int64
@@ -253,34 +224,6 @@ func (n *NodeManager) SearchByEnv(env, term, target string, hours int64) ([]Osqu
 	return nodes, nil
 }
 
-// SearchByEnvPage performs a paginated search
-func (n *NodeManager) SearchByEnvPage(env, term, target string, hours int64, offset, limit int, orderBy string, desc bool) ([]OsqueryNode, error) {
-	if limit <= 0 {
-		limit = 25
-	} else if limit > 500 {
-		limit = 500
-	}
-	if offset < 0 {
-		offset = 0
-	}
-	var nodes []OsqueryNode
-	likeTerm := "%" + term + "%"
-	query := n.DB.Where("environment = ? AND (uuid LIKE ? OR hostname LIKE ? OR localname LIKE ? OR ip_address LIKE ? OR username LIKE ? OR osquery_user LIKE ? OR platform LIKE ? OR osquery_version LIKE ?)", env, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm)
-	query = ApplyNodeTarget(query, target, hours)
-	orderExpr := "last_seen DESC"
-	if orderBy != "" {
-		direction := "ASC"
-		if desc {
-			direction = "DESC"
-		}
-		orderExpr = orderBy + " " + direction
-	}
-	if err := query.Order(orderExpr).Offset(offset).Limit(limit).Find(&nodes).Error; err != nil {
-		return nodes, err
-	}
-	return nodes, nil
-}
-
 // CountSearchByEnv counts matching nodes for a search term with target filters
 func (n *NodeManager) CountSearchByEnv(env, term, target string, hours int64) (int64, error) {
 	likeTerm := "%" + term + "%"
@@ -395,6 +338,17 @@ func (n *NodeManager) GetStatsByEnv(environment string, hours int64) (StatsData,
 	return GetStats(n.DB, EnvironmentSelector, environment, hours)
 }
 
+// GetPlatformCountsByEnv exposes the package-level helper through NodeManager
+// so handlers don't reach into n.DB directly.
+func (n *NodeManager) GetPlatformCountsByEnv(environment string) (PlatformCounts, error) {
+	return GetPlatformCountsByEnv(n.DB, environment)
+}
+
+// GetOsqueryVersionCounts wrapper.
+func (n *NodeManager) GetOsqueryVersionCounts() ([]OsqueryVersionCount, error) {
+	return GetOsqueryVersionCounts(n.DB)
+}
+
 // UpdateMetadataByUUID to update node metadata by UUID
 func (n *NodeManager) UpdateMetadataByUUID(uuid string, metadata NodeMetadata) error {
 	// Retrieve node
@@ -550,6 +504,153 @@ func (n *NodeManager) MetadataRefresh(node OsqueryNode, updates map[string]inter
 	return n.DB.Model(&node).Updates(updates).Error
 }
 
+// SortableColumns is the closed set of columns that may be ordered by external
+// callers. Enforced in GetByEnvPaged so the allowlist is part of the data layer,
+// not just the HTTP handler. Resolves audit finding U-DB-1.
+var SortableColumns = map[string]string{ + "uuid": "uuid", + "hostname": "hostname", + "localname": "localname", + "ip": "ip_address", + "platform": "platform", + "version": "platform_version", + "osquery": "osquery_version", + "lastseen": "last_seen", + "firstseen": "created_at", +} + +// NodesPage is the canonical paginated-list result for nodes. +type NodesPage struct { + Items []OsqueryNode + TotalItems int64 +} + +// GetByEnvPaged returns a page of nodes for an environment, applying the target +// filter (all / active / inactive), optional search, optional sort, and the +// optional platform bucket filter ("linux" / "darwin" / "windows" / "other"). +// The sort column is validated against SortableColumns; unknown columns fall +// back to last_seen DESC. This is the single canonical paginated reader. +// +// page is 1-indexed. pageSize is clamped to [1, 500] with default 50. +// platformBucket is one of the buckets normalizePlatformBucket recognises; an +// empty string disables the filter. Unknown buckets also disable it (so the +// caller can pass user input directly without input-validation boilerplate). +func (n *NodeManager) GetByEnvPaged(env, target string, hours int64, search string, page, pageSize int, sortColumn string, desc bool, platformBucket string) (NodesPage, error) { + if pageSize <= 0 { + pageSize = 50 + } + if pageSize > 500 { + pageSize = 500 + } + if page <= 0 { + page = 1 + } + offset := (page - 1) * pageSize + + // Resolve sort column against the package allowlist; fall back to last_seen + // if the caller asked for something we don't allow. + dbColumn, ok := SortableColumns[sortColumn] + if !ok || sortColumn == "" { + dbColumn = "last_seen" + desc = true + } + direction := "ASC" + if desc { + direction = "DESC" + } + // dbColumn is always from the allowlist — safe to interpolate. 
+ orderExpr := fmt.Sprintf("%s %s", dbColumn, direction) + + // Build the base query + query := n.DB.Model(&OsqueryNode{}).Where("environment = ?", env) + query = ApplyNodeTarget(query, target, hours) + query = applyPlatformBucket(query, platformBucket) + if search != "" { + like := "%" + search + "%" + query = query.Where( + "uuid LIKE ? OR hostname LIKE ? OR localname LIKE ? OR ip_address LIKE ? OR username LIKE ? OR osquery_user LIKE ? OR platform LIKE ? OR osquery_version LIKE ?", + like, like, like, like, like, like, like, like, + ) + } + + var total int64 + if err := query.Count(&total).Error; err != nil { + return NodesPage{}, err + } + + var items []OsqueryNode + if err := query.Order(orderExpr).Offset(offset).Limit(pageSize).Find(&items).Error; err != nil { + return NodesPage{}, err + } + return NodesPage{Items: items, TotalItems: total}, nil +} + +// safeOrderExpr translates a caller-supplied orderBy column name into a +// safe `ORDER BY [ASC|DESC]` expression. The column name is +// gated by SortableColumns (the same allowlist GetByEnvPaged uses); an +// unknown/empty key falls back to the default `last_seen DESC` rather +// than splicing user input into SQL. +func safeOrderExpr(orderBy string, desc bool) string { + if orderBy == "" { + return "last_seen DESC" + } + col, ok := SortableColumns[orderBy] + if !ok { + return "last_seen DESC" + } + dir := "ASC" + if desc { + dir = "DESC" + } + return col + " " + dir +} + +// Deprecated: prefer GetByEnvPaged which applies the column allowlist at +// the package layer and unifies search, paging, and sorting into a +// single call. Retained for the legacy admin UI's callers in +// cmd/admin/handlers/json-nodes.go; the orderBy parameter is gated by +// SortableColumns so an unknown column silently falls back to +// `last_seen DESC` rather than interpolating into SQL. 
+func (n *NodeManager) GetByEnvPage(env, target string, hours int64, offset, limit int, orderBy string, desc bool) ([]OsqueryNode, error) { + var nodeList []OsqueryNode + if limit <= 0 { // safety default + limit = 25 + } + if limit > 500 { // cap to avoid abuse + limit = 500 + } + if offset < 0 { + offset = 0 + } + query := n.DB.Where("environment = ?", env) + query = ApplyNodeTarget(query, target, hours) + if err := query.Order(safeOrderExpr(orderBy, desc)).Offset(offset).Limit(limit).Find(&nodeList).Error; err != nil { + return nodeList, err + } + return nodeList, nil +} + +// Deprecated: prefer GetByEnvPaged. Same orderBy hardening as +// GetByEnvPage. +func (n *NodeManager) SearchByEnvPage(env, term, target string, hours int64, offset, limit int, orderBy string, desc bool) ([]OsqueryNode, error) { + if limit <= 0 { + limit = 25 + } else if limit > 500 { + limit = 500 + } + if offset < 0 { + offset = 0 + } + var nodeList []OsqueryNode + likeTerm := "%" + term + "%" + query := n.DB.Where("environment = ? AND (uuid LIKE ? OR hostname LIKE ? OR localname LIKE ? OR ip_address LIKE ? OR username LIKE ? OR osquery_user LIKE ? OR platform LIKE ? OR osquery_version LIKE ?)", env, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm, likeTerm) + query = ApplyNodeTarget(query, target, hours) + if err := query.Order(safeOrderExpr(orderBy, desc)).Offset(offset).Limit(limit).Find(&nodeList).Error; err != nil { + return nodeList, err + } + return nodeList, nil +} + // CountAll to count all nodes func (n *NodeManager) CountAll() (int64, error) { var count int64 diff --git a/pkg/nodes/nodes_test.go b/pkg/nodes/nodes_test.go new file mode 100644 index 00000000..90795cf2 --- /dev/null +++ b/pkg/nodes/nodes_test.go @@ -0,0 +1,77 @@ +package nodes + +import "testing" + +// TestSortableColumnsAllowlist verifies that every entry in SortableColumns +// maps to a non-empty database column name and that the SPA-critical keys +// resolve to the expected columns. 
+func TestSortableColumnsAllowlist(t *testing.T) { + // Every key must map to a non-empty db column. + for k, v := range SortableColumns { + if v == "" { + t.Errorf("SortableColumns[%q] is empty", k) + } + } + + // Spot-check the contract used by the SPA. + cases := map[string]string{ + "uuid": "uuid", + "lastseen": "last_seen", + "firstseen": "created_at", + "ip": "ip_address", + "hostname": "hostname", + "localname": "localname", + "platform": "platform", + "version": "platform_version", + "osquery": "osquery_version", + } + for k, want := range cases { + got, ok := SortableColumns[k] + if !ok { + t.Errorf("SortableColumns missing expected key %q", k) + continue + } + if got != want { + t.Errorf("SortableColumns[%q] = %q, want %q", k, got, want) + } + } +} + +func TestSortableColumnsRejectsUnknown(t *testing.T) { + if _, ok := SortableColumns["unknown_column"]; ok { + t.Error("SortableColumns should not contain unknown_column") + } + if _, ok := SortableColumns[""]; ok { + t.Error("SortableColumns should not contain the empty key") + } + if _, ok := SortableColumns["DROP TABLE"]; ok { + t.Error("SortableColumns should not contain SQL fragments") + } +} + +// TestSafeOrderExpr verifies the deprecated GetByEnvPage / SearchByEnvPage +// callers can never inject SQL via orderBy — unknown / empty / malicious +// values all fall back to the safe default. 
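Outside the diff, the allowlist defense these tests pin down reduces to a map lookup with a safe fallback. A minimal standalone sketch (names like `allowedSort` and `orderExpr` are illustrative, not the package's API):

```go
package main

import "fmt"

// allowedSort mirrors the idea of SortableColumns: caller-facing sort keys on
// the left, real column names on the right. Anything absent falls back.
var allowedSort = map[string]string{
	"hostname": "hostname",
	"lastseen": "last_seen",
}

// orderExpr builds an ORDER BY fragment. Because the column always comes from
// the map values (never from the caller's raw string), interpolation is safe.
func orderExpr(key string, desc bool) string {
	col, ok := allowedSort[key]
	if !ok {
		// Unknown, empty, or malicious keys all land here.
		return "last_seen DESC"
	}
	dir := "ASC"
	if desc {
		dir = "DESC"
	}
	return fmt.Sprintf("%s %s", col, dir)
}

func main() {
	fmt.Println(orderExpr("hostname", true))         // hostname DESC
	fmt.Println(orderExpr("1; DROP TABLE x", false)) // last_seen DESC
}
```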
+func TestSafeOrderExpr(t *testing.T) { + cases := []struct { + name string + orderBy string + desc bool + want string + }{ + {"empty falls back", "", false, "last_seen DESC"}, + {"unknown column falls back", "DROP TABLE", true, "last_seen DESC"}, + {"injection attempt falls back", "1; SELECT 1", false, "last_seen DESC"}, + // uuid is in SortableColumns + {"allowlisted asc", "uuid", false, "uuid ASC"}, + {"allowlisted desc", "uuid", true, "uuid DESC"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := safeOrderExpr(tc.orderBy, tc.desc) + if got != tc.want { + t.Errorf("safeOrderExpr(%q, %v) = %q, want %q", tc.orderBy, tc.desc, got, tc.want) + } + }) + } +} diff --git a/pkg/nodes/utils.go b/pkg/nodes/utils.go index 605bf1b9..4690cfad 100644 --- a/pkg/nodes/utils.go +++ b/pkg/nodes/utils.go @@ -71,3 +71,131 @@ func GetStats(db *gorm.DB, column, value string, hours int64) (StatsData, error) return stats, nil } + +// PlatformCounts buckets nodes by `platform` value. Three families are +// normalized into the canonical osquery-side names; everything else lands in +// Other. The buckets mirror what the SPA's Nodes-table QuickFilters chip row +// shows ([Linux] [Windows] [macOS] [Other]). +type PlatformCounts struct { + Linux int64 `json:"linux"` + Darwin int64 `json:"darwin"` + Windows int64 `json:"windows"` + Other int64 `json:"other"` +} + +// OsqueryVersionCount is one row of the osquery-versions breakdown. Used by +// the dashboard's "agent fleet hygiene" panel to spot stale agents. +type OsqueryVersionCount struct { + Version string `json:"version"` + Count int64 `json:"count"` +} + +// GetOsqueryVersionCounts returns the per-version node counts across the +// whole fleet (no env arg — the dashboard renders fleet-wide; if a per-env +// variant is wanted later it lives next to this one). Sorted by count DESC +// so the most-common version sits first. One GROUP BY query. 
+func GetOsqueryVersionCounts(db *gorm.DB) ([]OsqueryVersionCount, error) { + var rows []OsqueryVersionCount + err := db.Model(&OsqueryNode{}). + Select("osquery_version AS version, COUNT(*) AS count"). + Where("osquery_version <> ''"). + Group("osquery_version"). + Order("count DESC"). + Scan(&rows).Error + if err != nil { + return nil, err + } + return rows, nil +} + +// GetPlatformCountsByEnv returns the per-platform node counts for one env. +// One GROUP BY `platform` query, then we bucket the rows in Go because +// osquery agents report `kali`, `ubuntu`, `centos`, etc. — all of which +// collapse into the `linux` bucket. Doing the mapping client-side keeps the +// SQL portable and easy to extend. +// +// Counts include both active and inactive nodes — that's the right shape for +// a "this env runs 12 Linux boxes" filter chip; "how many of those are active +// right now" lives on StatsData and is rendered separately. +func GetPlatformCountsByEnv(db *gorm.DB, environment string) (PlatformCounts, error) { + var rows []struct { + Platform string + N int64 + } + err := db.Model(&OsqueryNode{}). + Select("platform, COUNT(*) AS n"). + Where("environment = ?", environment). + Group("platform"). + Scan(&rows).Error + var out PlatformCounts + if err != nil { + return out, err + } + for _, r := range rows { + switch normalizePlatformBucket(r.Platform) { + case "linux": + out.Linux += r.N + case "darwin": + out.Darwin += r.N + case "windows": + out.Windows += r.N + default: + out.Other += r.N + } + } + return out, nil +} + +// platformsByBucket is the inverse of normalizePlatformBucket — given a +// canonical bucket name, return the literal `platform` column values that +// belong in it. Used by applyPlatformBucket to add an `IN (...)` filter. +// Kept in sync with normalizePlatformBucket; the two functions share the +// list of recognised distros so a change here without one there would +// silently mis-bucket nodes. 
+var platformsByBucket = map[string][]string{ + "linux": { + "linux", "kali", "ubuntu", "debian", "centos", "rhel", "fedora", + "arch", "amzn", "amazon", "opensuse", "sles", "alpine", "rocky", + "oracle", "almalinux", + }, + "darwin": {"darwin", "macos", "mac"}, + "windows": {"windows", "win", "win32", "win64"}, +} + +// applyPlatformBucket narrows a node query to one of the four buckets. +// Empty / unknown bucket → no filter (passthrough). +// "other" is the negation of (linux ∪ darwin ∪ windows): every platform that +// doesn't appear in any known list. Implemented as `platform NOT IN (...)`. +func applyPlatformBucket(q *gorm.DB, bucket string) *gorm.DB { + if bucket == "" { + return q + } + if vals, ok := platformsByBucket[bucket]; ok { + return q.Where("platform IN ?", vals) + } + if bucket == "other" { + // Everything not in any recognised bucket. + all := make([]string, 0, 32) + for _, vals := range platformsByBucket { + all = append(all, vals...) + } + return q.Where("platform NOT IN ?", all) + } + // Unknown bucket — caller can pass user input safely; no filter applied. + return q +} + +// normalizePlatformBucket folds the osquery-reported platform string into the +// SPA-facing buckets. Reads from platformsByBucket so we only maintain one +// list of recognised distros. Anything not in any bucket lands in "other". +func normalizePlatformBucket(p string) string { + for bucket, vals := range platformsByBucket { + for _, v := range vals { + if v == p { + return bucket + } + } + } + return "other" +} diff --git a/pkg/osquery/tables.go b/pkg/osquery/tables.go new file mode 100644 index 00000000..9b0854d2 --- /dev/null +++ b/pkg/osquery/tables.go @@ -0,0 +1,34 @@ +// Package osquery provides shared helpers for working with the osquery schema. +package osquery + +import ( + "encoding/json" + "os" + "strings" + + "github.com/jmpsec/osctrl/pkg/types" +) + +// LoadTables reads the osquery schema JSON file at path and returns a slice of +// OsqueryTable values. 
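The fold-to-bucket contract above can be sketched independently of GORM. This condensed example uses an abbreviated distro list and illustrative names (`bucketsAbbrev`, `bucketFor`); the real list in pkg/nodes carries many more distros:

```go
package main

import "fmt"

// bucketsAbbrev is a cut-down stand-in for platformsByBucket, enough to show
// the shape of the mapping.
var bucketsAbbrev = map[string][]string{
	"linux":   {"linux", "ubuntu", "centos"},
	"darwin":  {"darwin", "macos"},
	"windows": {"windows"},
}

// bucketFor folds a raw platform string into a bucket, defaulting to "other",
// matching normalizePlatformBucket's contract.
func bucketFor(p string) string {
	for bucket, vals := range bucketsAbbrev {
		for _, v := range vals {
			if v == p {
				return bucket
			}
		}
	}
	return "other"
}

func main() {
	fmt.Println(bucketFor("ubuntu"))  // linux
	fmt.Println(bucketFor("freebsd")) // other
}
```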
It mirrors the logic previously inlined in +// cmd/admin/utils.go loadOsqueryTables so both admin and api can share it. +func LoadTables(path string) ([]types.OsqueryTable, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var tables []types.OsqueryTable + if err := json.Unmarshal(b, &tables); err != nil { + return nil, err + } + // Build the filter string used for platform-based CSS filtering in the + // legacy admin templates. Kept here for parity; the API returns it too. + for i, t := range tables { + filter := "" + for _, p := range t.Platforms { + filter += " filter-" + p + } + tables[i].Filter = strings.TrimSpace(filter) + } + return tables, nil +} diff --git a/pkg/queries/queries.go b/pkg/queries/queries.go index d1beaf3b..9575b748 100644 --- a/pkg/queries/queries.go +++ b/pkg/queries/queries.go @@ -4,11 +4,30 @@ import ( "fmt" "time" + "github.com/jmpsec/osctrl/pkg/dbutil" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/rs/zerolog/log" "gorm.io/gorm" ) +// QueryListPage is the canonical paginated-list result for queries. +type QueryListPage struct { + Items []DistributedQuery + TotalItems int64 +} + +// QuerySortableColumns is the closed set of columns external callers may sort by. +// Enforced in GetByEnvTargetPaged. Mirrors the SortableColumns convention from pkg/nodes. +var QuerySortableColumns = map[string]string{ + "name": "name", + "creator": "creator", + "created": "created_at", + "type": "type", + "expected": "expected", + "executions": "executions", + "errors": "errors", +} + const ( // QueryTargetPlatform defines platform as target QueryTargetPlatform string = "platform" @@ -65,27 +84,36 @@ const ( DistributedQueryStatusExpired string = "expired" ) -// DistributedQuery as abstraction of a distributed query +// DistributedQuery as abstraction of a distributed query. 
+// +// Explicit JSON tags (rather than relying on Go's default-PascalCase +// behavior or an external view projection) so /api/v1/queries and +// /api/v1/carves responses match the SPA's snake_case contract directly. +// Fields here are equivalent to embedding gorm.Model — same schema and +// soft-delete semantics — just with field-level json tags. type DistributedQuery struct { - gorm.Model - Name string `gorm:"not null;unique;index"` - Creator string - Query string - Expected int - Executions int - Errors int - Active bool - Hidden bool - Protected bool - Completed bool - Deleted bool - Expired bool - Type string - Path string - EnvironmentID uint - ExtraData string - Expiration time.Time - Target string + ID uint `gorm:"primarykey" json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` + Name string `gorm:"not null;unique;index" json:"name"` + Creator string `json:"creator"` + Query string `json:"query"` + Expected int `json:"expected"` + Executions int `json:"executions"` + Errors int `json:"errors"` + Active bool `json:"active"` + Hidden bool `json:"hidden"` + Protected bool `json:"protected"` + Completed bool `json:"completed"` + Deleted bool `json:"deleted"` + Expired bool `json:"expired"` + Type string `json:"type"` + Path string `json:"path"` + EnvironmentID uint `json:"environment_id"` + ExtraData string `json:"extra_data"` + Expiration time.Time `json:"expiration"` + Target string `json:"target"` } // NodeQuery links a node to a query @@ -287,6 +315,35 @@ func (q *Queries) Get(name string, envid uint) (DistributedQuery, error) { return query, nil } +// GetNodeQueryTimestamps returns just the CreatedAt of every node_query row +// where this node was the target, since the cutoff. Used by the per-node +// activity heatmap. 
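The effect of the explicit tags can be seen with a small marshaling sketch. `miniQuery` here is an illustrative subset of the struct's fields, not the real type; the point is that each key comes out snake_case rather than Go's default PascalCase:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// miniQuery shows the tagging pattern: explicit snake_case json tags on every
// exported field.
type miniQuery struct {
	ID            uint   `json:"id"`
	Name          string `json:"name"`
	EnvironmentID uint   `json:"environment_id"`
}

// encode marshals to the wire shape; Marshal cannot fail for this struct.
func encode(q miniQuery) string {
	b, _ := json.Marshal(q)
	return string(b)
}

func main() {
	fmt.Println(encode(miniQuery{ID: 7, Name: "uptime", EnvironmentID: 2}))
	// {"id":7,"name":"uptime","environment_id":2}
}
```

Without the tags, the same struct would serialize as `{"ID":7,"Name":"uptime","EnvironmentID":2}`, which is why the SPA contract needs them spelled out.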
+// +// Pluck-style — drags only one column across the wire so the heatmap stays +// cheap when nodes have many tens of thousands of distributed queries. +func (q *Queries) GetNodeQueryTimestamps(nodeID uint, since time.Time) ([]time.Time, error) { + var ts []time.Time + err := q.DB.Model(&NodeQuery{}). + Where("node_id = ? AND created_at >= ?", nodeID, since). + Pluck("created_at", &ts).Error + return ts, err +} + +// GetNodeQueryBucketed returns per-bucket row counts for node_queries +// targeting `nodeID`, since `since`. Same bucketing semantics as the +// logging-package variants — see pkg/dbutil.BucketExpr for the dialect +// branching. +func (q *Queries) GetNodeQueryBucketed(nodeID uint, since time.Time, bucketSeconds int) ([]dbutil.BucketedRow, error) { + expr := dbutil.BucketExpr(q.DB, "created_at", bucketSeconds) + var rows []dbutil.BucketedRow + err := q.DB.Model(&NodeQuery{}). + Select(expr+" AS bucket_start, COUNT(*) AS cnt"). + Where("node_id = ? AND created_at >= ?", nodeID, since). + Group("bucket_start"). + Scan(&rows).Error + return rows, err +} + // Complete to mark query as completed func (q *Queries) Complete(name string, envid uint) error { query, err := q.Get(name, envid) @@ -517,3 +574,74 @@ func (q *Queries) SetNodeQueriesAsExpired(queryID uint) error { return nil } + +// GetByEnvTargetPaged returns a page of queries for an env + target, +// with optional free-text search on name/creator/query, optional sort, and +// canonical pagination. qtype: StandardQueryType or CarveQueryType. +// +// page is 1-indexed. pageSize is clamped to [1, 500] with default 50. 
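The pagination contract stated above (pageSize clamped to [1, 500], default 50, 1-indexed page) can be isolated into a few lines. `clampPage` is a hypothetical helper for illustration, not part of the package:

```go
package main

import "fmt"

// clampPage resolves caller-supplied paging values into a LIMIT/OFFSET pair,
// using the same rules as GetByEnvPaged / GetByEnvTargetPaged.
func clampPage(page, pageSize int) (limit, offset int) {
	if pageSize <= 0 {
		pageSize = 50 // default
	}
	if pageSize > 500 {
		pageSize = 500 // cap to avoid abuse
	}
	if page <= 0 {
		page = 1 // pages are 1-indexed
	}
	return pageSize, (page - 1) * pageSize
}

func main() {
	l, o := clampPage(3, 25)
	fmt.Println(l, o) // 25 50
	l, o = clampPage(0, 9999)
	fmt.Println(l, o) // 500 0
}
```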
+func (q *Queries) GetByEnvTargetPaged(envID uint, target, qtype, search string, page, pageSize int, sortColumn string, desc bool) (QueryListPage, error) { + if pageSize <= 0 { + pageSize = 50 + } + if pageSize > 500 { + pageSize = 500 + } + if page <= 0 { + page = 1 + } + offset := (page - 1) * pageSize + + dbCol, ok := QuerySortableColumns[sortColumn] + if !ok || sortColumn == "" { + dbCol = "created_at" + desc = true + } + dir := "ASC" + if desc { + dir = "DESC" + } + orderExpr := fmt.Sprintf("%s %s", dbCol, dir) + + db := q.DB.Model(&DistributedQuery{}).Where("environment_id = ? AND type = ?", envID, qtype) + // Apply the same target filtering as Gets(): + switch target { + case TargetActive: + db = db.Where("active = ? AND completed = ? AND deleted = ? AND expired = ?", true, false, false, false) + case TargetCompleted: + db = db.Where("active = ? AND completed = ? AND deleted = ? AND expired = ?", false, true, false, false) + case TargetHiddenCompleted: + db = db.Where("active = ? AND completed = ? AND deleted = ? AND hidden = ?", false, true, false, true) + case TargetAllFull: + db = db.Where("deleted = ?", false) + case TargetAll: + db = db.Where("deleted = ? AND hidden = ?", false, false) + case TargetDeleted: + db = db.Where("deleted = ?", true) + case TargetHidden: + db = db.Where("deleted = ? AND hidden = ?", false, true) + case TargetExpired: + db = db.Where("active = ? AND expired = ? AND deleted = ?", false, true, false) + case TargetSaved: + // Saved queries are not yet implemented as a separate table (Track 4 will). + // Mirror Gets() semantics by returning zero rows here. + db = db.Where("1 = 0") + default: + return QueryListPage{}, fmt.Errorf("invalid target %q", target) + } + + if search != "" { + like := "%" + search + "%" + db = db.Where("name LIKE ? OR creator LIKE ? 
OR query LIKE ?", like, like, like) + } + + var total int64 + if err := db.Count(&total).Error; err != nil { + return QueryListPage{}, err + } + var items []DistributedQuery + if err := db.Order(orderExpr).Offset(offset).Limit(pageSize).Find(&items).Error; err != nil { + return QueryListPage{}, err + } + return QueryListPage{Items: items, TotalItems: total}, nil +} diff --git a/pkg/queries/queries_test.go b/pkg/queries/queries_test.go index 284a200f..a1d2ca17 100644 --- a/pkg/queries/queries_test.go +++ b/pkg/queries/queries_test.go @@ -37,14 +37,14 @@ func setupTestData(t *testing.T, db *gorm.DB) (*queries.Queries, []nodes.Osquery // Create test nodes testNodes := []nodes.OsqueryNode{ - {Model: gorm.Model{ID: 1}}, - {Model: gorm.Model{ID: 2}}, - {Model: gorm.Model{ID: 3}}, + {ID: 1}, + {ID: 2}, + {ID: 3}, } // Create test query testQuery := &queries.DistributedQuery{ - Model: gorm.Model{ID: 1}, + ID: 1, Name: "test_query", Query: "SELECT * FROM osquery_info;", EnvironmentID: 1, @@ -171,6 +171,25 @@ func TestCreateNodeQueries(t *testing.T) { }) } +func TestQuerySortableColumnsAllowlist(t *testing.T) { + if _, ok := queries.QuerySortableColumns["unknown"]; ok { + t.Error("unknown should not be allowed") + } + if _, ok := queries.QuerySortableColumns[""]; ok { + t.Error("empty key should not be allowed") + } + if _, ok := queries.QuerySortableColumns["DROP TABLE"]; ok { + t.Error("SQL fragment should not be allowed") + } + // Spot-check what the SPA depends on. 
+ if queries.QuerySortableColumns["name"] != "name" { + t.Error("name → name") + } + if queries.QuerySortableColumns["created"] != "created_at" { + t.Error("created → created_at") + } +} + func TestSetNodeQueriesAsExpired(t *testing.T) { db := testDB(t) q, nodes, query := setupTestData(t, db) diff --git a/pkg/queries/samples.go b/pkg/queries/samples.go new file mode 100644 index 00000000..b522e82f --- /dev/null +++ b/pkg/queries/samples.go @@ -0,0 +1,275 @@ +package queries + +// Starter osquery query samples shipped with osctrl. Used by: +// - GET /api/v1/queries/samples — SPA queries/new form populates its +// QuickTemplates row from this list so new operators have ready-made +// examples to learn from. +// - cmd/cli env add — seeds a SavedQuery row per sample into the new +// environment so the Saves page is not empty out of the box. +// +// Each sample is a pure data record; no database interaction. The list lives +// here (rather than baked into the SPA bundle) so the CLI and the SPA stay +// in sync — both load from the same source. +// +// Editing rules: +// - Names must be unique. The CLI uses Name as the primary key when +// seeding into saved_queries (one-row-per-sample-per-env). +// - SQL must be a single statement and must NOT end in a semicolon — +// the existing query infrastructure appends one and double-semicolons +// break some platforms. +// - Keep platform tags accurate. The SPA filters the templates row by +// selected platforms in the run form; a sample tagged `linux` won't +// appear when an operator has only `windows` selected. + +// QuerySampleCategory is the closed set of category tags. Surfaced in the +// SPA so templates can group; kept as a typed string so a typo at sample-add +// time becomes a compile error. 
+type QuerySampleCategory string + +const ( + CategoryRecon QuerySampleCategory = "recon" + CategoryProcesses QuerySampleCategory = "processes" + CategoryUsers QuerySampleCategory = "users" + CategoryNetwork QuerySampleCategory = "network" + CategoryPersistence QuerySampleCategory = "persistence" + CategoryFileIntegrity QuerySampleCategory = "file_integrity" + CategoryPackages QuerySampleCategory = "packages" +) + +// QuerySamplePlatform — a platform tag a sample claims to support. Aligns +// with pkg/nodes platform buckets (linux / darwin / windows). A sample +// applicable to every platform is tagged with `linux, darwin, windows`. +type QuerySamplePlatform string + +const ( + PlatformLinux QuerySamplePlatform = "linux" + PlatformDarwin QuerySamplePlatform = "darwin" + PlatformWindows QuerySamplePlatform = "windows" +) + +// QuerySample is one starter sample row. +type QuerySample struct { + Name string `json:"name"` + Description string `json:"description"` + SQL string `json:"sql"` + Category QuerySampleCategory `json:"category"` + Platforms []QuerySamplePlatform `json:"platforms"` +} + +// QuerySamples is the canonical starter library. 28 entries spanning the +// categories above. Operators are expected to read, clone, and adapt these — +// they are intentionally simple and SELECT-only. +// +// Ordering matters: this is the order the SPA template row renders, so the +// most-commonly-useful samples sit first. 
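How a client narrows the library by selected platform, as the SPA's template row is described to do, can be sketched with simplified types (`sample` and `forPlatform` are illustrative stand-ins, not the package's API):

```go
package main

import "fmt"

// sample is a simplified stand-in for QuerySample: just a name plus the
// platform tags it claims to support.
type sample struct {
	Name      string
	Platforms []string
}

// forPlatform keeps only the samples that claim support for the selected
// platform, preserving library order.
func forPlatform(all []sample, platform string) []sample {
	var out []sample
	for _, s := range all {
		for _, p := range s.Platforms {
			if p == platform {
				out = append(out, s)
				break
			}
		}
	}
	return out
}

func main() {
	library := []sample{
		{Name: "systemd_units", Platforms: []string{"linux"}},
		{Name: "uptime", Platforms: []string{"linux", "darwin", "windows"}},
	}
	fmt.Println(len(forPlatform(library, "windows"))) // 1
}
```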
+var QuerySamples = []QuerySample{ + // ── recon — quick host snapshots ─────────────────────────────────────── + { + Name: "host_overview", + Description: "Hostname, computer name, CPU brand, physical memory — basic host identity.", + SQL: "SELECT hostname, computer_name, cpu_brand, physical_memory FROM system_info", + Category: CategoryRecon, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "os_version", + Description: "Operating system name, version, codename, and build identifiers.", + SQL: "SELECT name, version, codename, major, minor, patch, platform, platform_like FROM os_version", + Category: CategoryRecon, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "kernel_info", + Description: "Running kernel name and version.", + SQL: "SELECT name, version FROM kernel_info", + Category: CategoryRecon, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin}, + }, + { + Name: "uptime", + Description: "How long the host has been up — in days, hours, minutes.", + SQL: "SELECT days, hours, minutes, seconds FROM uptime", + Category: CategoryRecon, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + + // ── processes ────────────────────────────────────────────────────────── + { + Name: "running_processes", + Description: "All running processes — pid, name, full path, parent pid.", + SQL: "SELECT pid, name, path, parent FROM processes", + Category: CategoryProcesses, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "processes_root", + Description: "Processes running as root / SYSTEM. 
Quick way to spot abnormal privileged execution.", + SQL: "SELECT pid, name, path, uid, cmdline FROM processes WHERE uid = 0", + Category: CategoryProcesses, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin}, + }, + { + Name: "processes_no_disk", + Description: "Running processes whose executable on disk is missing — classic injected/memory-only indicator.", + SQL: "SELECT pid, name, path FROM processes WHERE on_disk = 0", + Category: CategoryProcesses, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + + // ── users ────────────────────────────────────────────────────────────── + { + Name: "local_users", + Description: "All local user accounts — username, uid, gid, home directory, shell.", + SQL: "SELECT username, uid, gid, directory, shell FROM users", + Category: CategoryUsers, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "logged_in_users", + Description: "Currently logged-in users with login time and remote host.", + SQL: "SELECT user, host, time, tty, type FROM logged_in_users", + Category: CategoryUsers, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "sudoers_groups", + Description: "Group memberships — useful for spotting unexpected sudo / wheel / admin members.", + SQL: "SELECT username, groupname FROM users JOIN user_groups USING(uid) JOIN groups USING(gid)", + Category: CategoryUsers, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin}, + }, + + // ── network ──────────────────────────────────────────────────────────── + { + Name: "listening_ports", + Description: "TCP/UDP listeners with the binding process and PID.", + SQL: "SELECT pid, port, protocol, address, p.name AS process FROM listening_ports l JOIN processes p USING(pid)", + Category: CategoryNetwork, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: 
"active_connections", + Description: "Established outbound TCP connections — remote IP and port.", + SQL: "SELECT pid, local_address, local_port, remote_address, remote_port FROM process_open_sockets WHERE state = 'ESTABLISHED'", + Category: CategoryNetwork, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "arp_cache", + Description: "ARP cache entries — recently-seen MAC↔IP pairs on the LAN.", + SQL: "SELECT address, mac, interface FROM arp_cache", + Category: CategoryNetwork, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "interface_addresses", + Description: "All network-interface addresses with subnet masks and broadcast addresses.", + SQL: "SELECT interface, address, mask, broadcast FROM interface_addresses", + Category: CategoryNetwork, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + + // ── persistence ──────────────────────────────────────────────────────── + { + Name: "crontab_all", + Description: "Every cron job on the host across system and per-user crontabs.", + SQL: "SELECT command, path, minute, hour, day_of_month, month, day_of_week FROM crontab", + Category: CategoryPersistence, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin}, + }, + { + Name: "systemd_units", + Description: "Loaded systemd units — name, state, file path. 
Look for unfamiliar service files.", + SQL: "SELECT id, fragment_path, active_state, sub_state, unit_file_state FROM systemd_units", + Category: CategoryPersistence, + Platforms: []QuerySamplePlatform{PlatformLinux}, + }, + { + Name: "launchd_overview", + Description: "macOS launchd jobs — daemons and agents loaded at boot/login.", + SQL: "SELECT name, path, program, run_at_load, keep_alive, disabled FROM launchd", + Category: CategoryPersistence, + Platforms: []QuerySamplePlatform{PlatformDarwin}, + }, + { + Name: "startup_items", + Description: "Windows autostart entries — Run/RunOnce registry keys and Startup folders.", + SQL: "SELECT name, path, source, status, type FROM startup_items", + Category: CategoryPersistence, + Platforms: []QuerySamplePlatform{PlatformWindows}, + }, + { + Name: "scheduled_tasks_windows", + Description: "Windows Task Scheduler jobs — name, action, last_run_time, enabled state.", + SQL: "SELECT name, action, path, enabled, last_run_time, next_run_time FROM scheduled_tasks", + Category: CategoryPersistence, + Platforms: []QuerySamplePlatform{PlatformWindows}, + }, + { + Name: "services_windows", + Description: "Windows services — name, display_name, start_type, status, path on disk.", + SQL: "SELECT name, display_name, status, start_type, path FROM services", + Category: CategoryPersistence, + Platforms: []QuerySamplePlatform{PlatformWindows}, + }, + + // ── file integrity ───────────────────────────────────────────────────── + { + Name: "etc_passwd", + Description: "Hash, size, owner, permissions of /etc/passwd — classic file-integrity check.", + SQL: "SELECT path, size, mode, uid, gid, mtime, sha256 FROM file WHERE path = '/etc/passwd'", + Category: CategoryFileIntegrity, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin}, + }, + { + Name: "etc_hosts_contents", + Description: "Lines of /etc/hosts — quick way to spot tampering or DNS-override mischief.", + SQL: "SELECT address, hostnames FROM etc_hosts", + Category: 
CategoryFileIntegrity, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + { + Name: "windows_hosts_file", + Description: "Hash and metadata of the Windows hosts file — should rarely change in a managed fleet.", + SQL: "SELECT path, size, mtime, sha256 FROM file WHERE path = 'C:\\Windows\\System32\\drivers\\etc\\hosts'", + Category: CategoryFileIntegrity, + Platforms: []QuerySamplePlatform{PlatformWindows}, + }, + { + Name: "certificates_trusted", + Description: "Trusted certificates in the system store — recent additions can indicate MITM CA installs.", + SQL: "SELECT common_name, subject, issuer, not_valid_after, sha1 FROM certificates", + Category: CategoryFileIntegrity, + Platforms: []QuerySamplePlatform{PlatformLinux, PlatformDarwin, PlatformWindows}, + }, + + // ── packages / installed software ────────────────────────────────────── + { + Name: "installed_packages_deb", + Description: "Debian / Ubuntu installed packages with version.", + SQL: "SELECT name, version, arch FROM deb_packages", + Category: CategoryPackages, + Platforms: []QuerySamplePlatform{PlatformLinux}, + }, + { + Name: "installed_packages_rpm", + Description: "RHEL / Fedora / CentOS installed RPM packages with version.", + SQL: "SELECT name, version, arch FROM rpm_packages", + Category: CategoryPackages, + Platforms: []QuerySamplePlatform{PlatformLinux}, + }, + { + Name: "installed_apps_macos", + Description: "macOS .app bundles in /Applications — name, version, bundle id.", + SQL: "SELECT name, bundle_identifier, bundle_short_version FROM apps", + Category: CategoryPackages, + Platforms: []QuerySamplePlatform{PlatformDarwin}, + }, + { + Name: "installed_programs_windows", + Description: "Windows installed programs — name, version, publisher, install_date.", + SQL: "SELECT name, version, publisher, install_date FROM programs", + Category: CategoryPackages, + Platforms: []QuerySamplePlatform{PlatformWindows}, + }, +} diff --git a/pkg/queries/saved.go 
b/pkg/queries/saved.go index 097ee5a2..32855058 100644 --- a/pkg/queries/saved.go +++ b/pkg/queries/saved.go @@ -1,21 +1,45 @@ package queries import ( + "errors" "fmt" + "strings" "gorm.io/gorm" ) -// SavedQuery as abstraction of a saved query to be used in distributed, schedule or packs +// SavedQuery as abstraction of a saved query to be used in distributed, schedule or packs. +// +// Composite unique index on (name, environment_id) — gorm AutoMigrate emits +// it as `idx_saved_query_name_env`. This is the structural fix for the +// TOCTOU race in SavedQueryCreateHandler: a concurrent pair of POSTs with +// the same name + env both pass the SavedExists precheck, both attempt +// CreateSaved; with the unique index, the second Create returns a +// duplicate-key error and the handler can map it to 409 cleanly. type SavedQuery struct { gorm.Model - Name string + Name string `gorm:"uniqueIndex:idx_saved_query_name_env"` Creator string Query string - EnvironmentID uint + EnvironmentID uint `gorm:"uniqueIndex:idx_saved_query_name_env"` ExtraData string } +// SavedQueryListPage is the canonical paginated-list result for saved queries. +type SavedQueryListPage struct { + Items []SavedQuery + TotalItems int64 +} + +// SavedQuerySortableColumns is the closed set of columns external callers may +// sort by. Enforced in GetSavedByEnvPaged. Mirrors QuerySortableColumns. +var SavedQuerySortableColumns = map[string]string{ + "name": "name", + "creator": "creator", + "created": "created_at", + "updated": "updated_at", +} + // GetSavedByCreator to get a saved query by creator func (q *Queries) GetSavedByCreator(creator string, envid uint) ([]SavedQuery, error) { var saved []SavedQuery @@ -25,16 +49,91 @@ func (q *Queries) GetSavedByCreator(creator string, envid uint) ([]SavedQuery, e return saved, nil } -// GetSaved to get a saved query by creator +// GetSaved to get a saved query by name + creator within an environment. 
+// Returns gorm.ErrRecordNotFound when no matching row exists — callers can +// use errors.Is(err, gorm.ErrRecordNotFound) to detect that case. func (q *Queries) GetSaved(name, creator string, envid uint) (SavedQuery, error) { var saved SavedQuery - if err := q.DB.Where("creator = ? AND name = ? AND environment_id = ?", creator, name, envid).Find(&saved).Error; err != nil { + if err := q.DB.Where("creator = ? AND name = ? AND environment_id = ?", creator, name, envid).First(&saved).Error; err != nil { return saved, err } return saved, nil } -// CreateSaved to create new saved query +// GetSavedByEnv returns a saved query by name within an environment without +// scoping by creator — used by env admins who can manage any saved query. +// Returns gorm.ErrRecordNotFound when no matching row exists. +func (q *Queries) GetSavedByEnv(name string, envid uint) (SavedQuery, error) { + var saved SavedQuery + if err := q.DB.Where("name = ? AND environment_id = ?", name, envid).First(&saved).Error; err != nil { + return saved, err + } + return saved, nil +} + +// SavedExists reports whether a saved query with the given name exists in the +// environment, irrespective of creator. +func (q *Queries) SavedExists(name string, envid uint) bool { + var count int64 + if err := q.DB.Model(&SavedQuery{}).Where("name = ? AND environment_id = ?", name, envid).Count(&count).Error; err != nil { + return false + } + return count > 0 +} + +// GetSavedByEnvPaged returns a page of saved queries for an env, with optional +// free-text search and an allowlisted sort column. pageSize is clamped to +// [1, 500]; pageSize <= 0 defaults to 50. page is 1-indexed. 
+func (q *Queries) GetSavedByEnvPaged(envid uint, search string, page, pageSize int, sortColumn string, desc bool) (SavedQueryListPage, error) { + if pageSize <= 0 { + pageSize = 50 + } + if pageSize > 500 { + pageSize = 500 + } + if page <= 0 { + page = 1 + } + offset := (page - 1) * pageSize + + dbCol, ok := SavedQuerySortableColumns[sortColumn] + if !ok || sortColumn == "" { + dbCol = "created_at" + desc = true + } + dir := "ASC" + if desc { + dir = "DESC" + } + orderExpr := fmt.Sprintf("%s %s", dbCol, dir) + + db := q.DB.Model(&SavedQuery{}).Where("environment_id = ?", envid) + if search != "" { + like := "%" + search + "%" + db = db.Where("name LIKE ? OR creator LIKE ? OR query LIKE ?", like, like, like) + } + + var total int64 + if err := db.Count(&total).Error; err != nil { + return SavedQueryListPage{}, err + } + var items []SavedQuery + if err := db.Order(orderExpr).Offset(offset).Limit(pageSize).Find(&items).Error; err != nil { + return SavedQueryListPage{}, err + } + return SavedQueryListPage{Items: items, TotalItems: total}, nil +} + +// ErrSavedQueryExists is returned by CreateSaved when the underlying +// unique index on (name, environment_id) rejects the insert because a +// row with the same key already exists. Callers should map this to a +// 409 Conflict response. +var ErrSavedQueryExists = errors.New("saved query already exists") + +// CreateSaved persists a new saved query. Returns ErrSavedQueryExists +// when a row with the same (name, env) already exists — the DB unique +// index `idx_saved_query_name_env` is the authoritative gate, so the +// handler does not need to win the SavedExists race anymore. 
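The clamp-then-allowlist pattern in GetSavedByEnvPaged is self-contained enough to sketch with the standard library alone. `clampPage` and `orderExpr` below are hypothetical names, not part of the package — they mirror the bounds and the fallback documented above:

```go
package main

import "fmt"

// sortable mirrors SavedQuerySortableColumns: API-facing key -> storage column.
var sortable = map[string]string{
	"name":    "name",
	"creator": "creator",
	"created": "created_at",
	"updated": "updated_at",
}

// clampPage applies the documented bounds: pageSize <= 0 -> 50,
// pageSize > 500 -> 500, page <= 0 -> 1. Returns page, pageSize, offset.
func clampPage(page, pageSize int) (int, int, int) {
	if pageSize <= 0 {
		pageSize = 50
	}
	if pageSize > 500 {
		pageSize = 500
	}
	if page <= 0 {
		page = 1
	}
	return page, pageSize, (page - 1) * pageSize
}

// orderExpr builds the ORDER BY expression from the allowlist only —
// unknown keys fall back to created_at DESC, so caller-supplied input
// never reaches the SQL string.
func orderExpr(sortColumn string, desc bool) string {
	col, ok := sortable[sortColumn]
	if !ok {
		col, desc = "created_at", true
	}
	dir := "ASC"
	if desc {
		dir = "DESC"
	}
	return fmt.Sprintf("%s %s", col, dir)
}

func main() {
	p, ps, off := clampPage(0, 9999)
	fmt.Println(p, ps, off)                     // 1 500 0
	fmt.Println(orderExpr("DROP TABLE", false)) // created_at DESC — fallback, never injected
}
```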
func (q *Queries) CreateSaved(name, query, creator string, envid uint) error { saved := SavedQuery{ Name: name, @@ -43,32 +142,63 @@ func (q *Queries) CreateSaved(name, query, creator string, envid uint) error { EnvironmentID: envid, } if err := q.DB.Create(&saved).Error; err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return ErrSavedQueryExists + } + // PG / MySQL drivers may bubble up the driver-specific dup-key + // error rather than gorm.ErrDuplicatedKey on some versions — + // fall back to a string match for the well-known sentinels so + // the handler still gets a clean 409 path. + es := err.Error() + if strings.Contains(es, "duplicate key") || strings.Contains(es, "Duplicate entry") || strings.Contains(es, "UNIQUE constraint") { + return ErrSavedQueryExists + } return err } return nil } -// UpdateSaved to update an existing saved query -func (q *Queries) UpdateSaved(name, query, creator string, envid uint) error { - saved, err := q.GetSaved(name, creator, envid) +// UpdateSaved updates the SQL body of an existing saved query identified by +// (name, env). The creator field is not modified — original ownership stays. +// Returns gorm.ErrRecordNotFound when the row does not exist. +func (q *Queries) UpdateSaved(name, query string, envid uint) error { + saved, err := q.GetSavedByEnv(name, envid) if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return err + } return fmt.Errorf("error getting saved query %w", err) } - data := SavedQuery{ - Name: name, - Query: query, - EnvironmentID: envid, - } - if err := q.DB.Model(&saved).Updates(data).Error; err != nil { + if err := q.DB.Model(&saved).Update("query", query).Error; err != nil { return fmt.Errorf("in Updates %w", err) } return nil } -// DeleteSaved to delete an existing saved query +// DeleteSavedByEnv removes a saved query by name within an environment. +// Returns gorm.ErrRecordNotFound when nothing matched. 
+func (q *Queries) DeleteSavedByEnv(name string, envid uint) error { + saved, err := q.GetSavedByEnv(name, envid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + return fmt.Errorf("error getting saved query %w", err) + } + if err := q.DB.Unscoped().Delete(&saved).Error; err != nil { + return fmt.Errorf("in DeleteSaved %w", err) + } + return nil +} + +// DeleteSaved removes a saved query owned by (creator, env, name). +// Retained for backward compatibility with non-API callers. func (q *Queries) DeleteSaved(name, creator string, envid uint) error { saved, err := q.GetSaved(name, creator, envid) if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return err + } return fmt.Errorf("error getting saved query %w", err) } if err := q.DB.Unscoped().Delete(&saved).Error; err != nil { diff --git a/pkg/queries/saved_test.go b/pkg/queries/saved_test.go new file mode 100644 index 00000000..18d029cd --- /dev/null +++ b/pkg/queries/saved_test.go @@ -0,0 +1,125 @@ +package queries_test + +import ( + "errors" + "testing" + + "github.com/jmpsec/osctrl/pkg/queries" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +// TestSavedQuerySortableColumns asserts the allowlist is closed and maps each +// API-facing key onto an actual storage column. The map is consulted from +// GetSavedByEnvPaged before any ORDER BY expression is built; if this +// allowlist drifts, the API stops accepting that sort key (which is the right +// behavior — we don't want to add a column the package can't translate).
+func TestSavedQuerySortableColumns(t *testing.T) { + want := map[string]string{ + "name": "name", + "creator": "creator", + "created": "created_at", + "updated": "updated_at", + } + assert.Equal(t, want, queries.SavedQuerySortableColumns) +} + +func TestSavedQueryCRUD(t *testing.T) { + db := testDB(t) + q := queries.CreateQueries(db) + + // Create + require.NoError(t, q.CreateSaved("first", "SELECT 1", "alice", 1)) + require.True(t, q.SavedExists("first", 1)) + require.False(t, q.SavedExists("first", 2)) // different env, still false + + // Duplicate in same env detected via SavedExists (handler enforces 409) + require.True(t, q.SavedExists("first", 1)) + + // GetSavedByEnv returns the row regardless of creator + got, err := q.GetSavedByEnv("first", 1) + require.NoError(t, err) + assert.Equal(t, "first", got.Name) + assert.Equal(t, "alice", got.Creator) + assert.Equal(t, "SELECT 1", got.Query) + + // GetSaved (creator-scoped) — same creator wins + got2, err := q.GetSaved("first", "alice", 1) + require.NoError(t, err) + assert.Equal(t, got.ID, got2.ID) + + // GetSaved with the wrong creator returns ErrRecordNotFound (not a zero row) + _, err = q.GetSaved("first", "bob", 1) + require.Error(t, err) + assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) + + // Update preserves creator + require.NoError(t, q.UpdateSaved("first", "SELECT 2", 1)) + updated, err := q.GetSavedByEnv("first", 1) + require.NoError(t, err) + assert.Equal(t, "SELECT 2", updated.Query) + assert.Equal(t, "alice", updated.Creator, "update must not overwrite creator") + + // Delete by env + require.NoError(t, q.DeleteSavedByEnv("first", 1)) + assert.False(t, q.SavedExists("first", 1)) + + // Deleting again surfaces ErrRecordNotFound + err = q.DeleteSavedByEnv("first", 1) + require.Error(t, err) + assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) +} + +func TestGetSavedByEnvPaged(t *testing.T) { + db := testDB(t) + q := queries.CreateQueries(db) + + // Seed across two envs to verify env 
scoping + require.NoError(t, q.CreateSaved("alpha", "SELECT a", "alice", 1)) + require.NoError(t, q.CreateSaved("beta", "SELECT b", "alice", 1)) + require.NoError(t, q.CreateSaved("gamma", "SELECT c", "bob", 1)) + require.NoError(t, q.CreateSaved("other_env", "SELECT z", "alice", 2)) + + // Default sort = created_at DESC, env 1 + page, err := q.GetSavedByEnvPaged(1, "", 0, 0, "", false) + require.NoError(t, err) + assert.Equal(t, int64(3), page.TotalItems, "env scoping leaks if this is != 3") + require.Len(t, page.Items, 3) + assert.Equal(t, "gamma", page.Items[0].Name, "newest first by default") + + // Search narrows to one row + page, err = q.GetSavedByEnvPaged(1, "alph", 0, 0, "", false) + require.NoError(t, err) + assert.Equal(t, int64(1), page.TotalItems) + require.Len(t, page.Items, 1) + assert.Equal(t, "alpha", page.Items[0].Name) + + // Sort by name asc + page, err = q.GetSavedByEnvPaged(1, "", 0, 0, "name", false) + require.NoError(t, err) + require.Len(t, page.Items, 3) + assert.Equal(t, []string{"alpha", "beta", "gamma"}, []string{ + page.Items[0].Name, page.Items[1].Name, page.Items[2].Name, + }) + + // Pagination — page_size 2, page 1 of 2 + page, err = q.GetSavedByEnvPaged(1, "", 1, 2, "name", false) + require.NoError(t, err) + require.Len(t, page.Items, 2) + assert.Equal(t, []string{"alpha", "beta"}, []string{ + page.Items[0].Name, page.Items[1].Name, + }) + assert.Equal(t, int64(3), page.TotalItems) + + // Pagination — page 2 + page, err = q.GetSavedByEnvPaged(1, "", 2, 2, "name", false) + require.NoError(t, err) + require.Len(t, page.Items, 1) + assert.Equal(t, "gamma", page.Items[0].Name) + + // Unknown sort key falls back to created_at DESC + page, err = q.GetSavedByEnvPaged(1, "", 0, 0, "DROP TABLE", false) + require.NoError(t, err, "unknown sort key must fall back, never inject") + require.Len(t, page.Items, 3) +} diff --git a/pkg/tags/tags.go b/pkg/tags/tags.go index 542969a2..88b82fdb 100644 --- a/pkg/tags/tags.go +++ b/pkg/tags/tags.go @@ 
-3,6 +3,7 @@ package tags import ( "fmt" "strings" + "time" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/rs/zerolog/log" @@ -46,19 +47,26 @@ const ( TagCustomTag string = TagTypeTagStr ) -// AdminTag to hold all tags +// AdminTag to hold all tags. +// +// Explicit JSON tags so /api/v1/tags responses match the SPA's snake_case +// contract. Fields are equivalent to embedding gorm.Model; we expand them +// so we can attach json tags to ID/CreatedAt/UpdatedAt/DeletedAt. type AdminTag struct { - gorm.Model - Name string `gorm:"index"` - Description string - Color string - Icon string - CreatedBy string - CustomTag string - AutoTag bool - EnvironmentID uint - TagType uint - Cohort bool + ID uint `gorm:"primarykey" json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` + Name string `gorm:"index" json:"name"` + Description string `json:"description"` + Color string `json:"color"` + Icon string `json:"icon"` + CreatedBy string `json:"created_by"` + CustomTag string `json:"custom_tag"` + AutoTag bool `json:"auto_tag"` + EnvironmentID uint `json:"environment_id"` + TagType uint `json:"tag_type"` + Cohort bool `json:"cohort"` } // AdminTagForNode to check if this tag is used for an specific node diff --git a/pkg/types/node_view.go b/pkg/types/node_view.go new file mode 100644 index 00000000..fc313555 --- /dev/null +++ b/pkg/types/node_view.go @@ -0,0 +1,199 @@ +package types + +import ( + "encoding/json" + + "github.com/jmpsec/osctrl/pkg/nodes" +) + +// SPA-facing node projections that surface the parsed-and-sanitized subset of +// nodes.OsqueryNode.RawEnrollment (the JSON blob osquery sends during enroll). +// RawEnrollment itself stays `json:"-"` on the DB model because it contains the +// env's enroll_secret. Everything below is the safe-to-expose subset. 
+// +// Why a separate projection rather than adding JSON tags to RawEnrollment: +// - Selective exposure: the enroll payload includes `enroll_secret`; we MUST +// drop it. Surface-by-surface field allowlisting is safer than blacklisting +// a single key on a `map[string]interface{}`. +// - Versioning: osquery's enrollment payload is osquery-side schema, not +// osctrl-side. If a future osquery release adds a field, we don't leak it +// until we explicitly add it here. +// - Backward compat: existing API consumers see exactly the same OsqueryNode +// shape they always did — `system_info` is an *additional* field with +// `omitempty`, so when parsing fails or the node has no raw enrollment it +// simply disappears. + +// SystemInfo mirrors host_details.system_info from the osquery enroll payload, +// minus the host_identifier / instance_id fields which are duplicates of data +// we already expose via OsqueryNode.UUID. +type SystemInfo struct { + HardwareVendor string `json:"hardware_vendor,omitempty"` + HardwareModel string `json:"hardware_model,omitempty"` + HardwareVersion string `json:"hardware_version,omitempty"` + HardwareSerial string `json:"hardware_serial,omitempty"` + CPUBrand string `json:"cpu_brand,omitempty"` + CPUType string `json:"cpu_type,omitempty"` + CPUSubtype string `json:"cpu_subtype,omitempty"` + CPUPhysicalCores string `json:"cpu_physical_cores,omitempty"` + CPULogicalCores string `json:"cpu_logical_cores,omitempty"` + PhysicalMemory string `json:"physical_memory,omitempty"` + ComputerName string `json:"computer_name,omitempty"` + LocalHostname string `json:"local_hostname,omitempty"` +} + +// BIOSInfo mirrors host_details.platform_info from the osquery enroll payload. +// "Platform info" in osquery's vocabulary is BIOS / firmware metadata; renamed +// here so the SPA naming aligns with what an operator expects to read. 
+type BIOSInfo struct { + Vendor string `json:"vendor,omitempty"` + Version string `json:"version,omitempty"` + Date string `json:"date,omitempty"` + Revision string `json:"revision,omitempty"` + Address string `json:"address,omitempty"` + Size string `json:"size,omitempty"` + VolumeSize string `json:"volume_size,omitempty"` +} + +// OSInfo mirrors host_details.os_version. Adds the few fields beyond what +// OsqueryNode.Platform / PlatformVersion already expose (codename, family). +type OSInfo struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` + Codename string `json:"codename,omitempty"` + Major string `json:"major,omitempty"` + Minor string `json:"minor,omitempty"` + Patch string `json:"patch,omitempty"` + Platform string `json:"platform,omitempty"` + PlatformLike string `json:"platform_like,omitempty"` +} + +// OsqueryRuntime mirrors host_details.osquery_info — the runtime / build +// metadata of the agent that enrolled. Useful for "this node is running an +// extensions-disabled build" diagnostics. Drops `instance_id`, `pid`, and +// `watcher` (PIDs) since they expose low-value process internals; keeps +// `start_time` so operators can see when the daemon last restarted. +type OsqueryRuntime struct { + Version string `json:"version,omitempty"` + BuildPlatform string `json:"build_platform,omitempty"` + BuildDistro string `json:"build_distro,omitempty"` + Extensions string `json:"extensions,omitempty"` + StartTime string `json:"start_time,omitempty"` + ConfigValid string `json:"config_valid,omitempty"` +} + +// NodeEnrichment is the projected view of everything we want to expose from +// nodes.OsqueryNode.RawEnrollment that isn't already on OsqueryNode itself. +// Embedded into NodeView with `json:"system_info,omitempty"` — the outer key +// is a slight abuse of the name (it carries BIOS + OS + runtime too) but it +// matches the heaviest sub-object and reads well in the SPA.
+type NodeEnrichment struct { + System *SystemInfo `json:"system,omitempty"` + BIOS *BIOSInfo `json:"bios,omitempty"` + OS *OSInfo `json:"os,omitempty"` + Osquery *OsqueryRuntime `json:"osquery,omitempty"` +} + +// NodeView is the JSON shape returned by the node show + list endpoints. +// It embeds OsqueryNode verbatim (so existing JSON fields stay) and adds the +// optional enrichment block. Consumers that don't care about the enrichment +// (CLI, dashboards) ignore the extra field; the SPA's Node Detail page reads +// from it directly. +type NodeView struct { + nodes.OsqueryNode + Enrichment *NodeEnrichment `json:"system_info,omitempty"` +} + +// ProjectNode wraps a single OsqueryNode into the SPA-facing NodeView, parsing +// RawEnrollment best-effort. A parse failure or an absent payload simply +// leaves Enrichment nil — the JSON `omitempty` then drops the key entirely so +// the SPA sees the same `OsqueryNode` shape it always saw, plus optional +// detail when available. +func ProjectNode(n nodes.OsqueryNode) NodeView { + view := NodeView{OsqueryNode: n} + if n.RawEnrollment == "" { + return view + } + // Parse into an intermediate map-of-maps because osquery's enroll payload + // shape is osquery-side and we don't want to maintain a parallel Go struct + // for every key. We only read the few keys we need. + var outer struct { + HostDetails struct { + SystemInfo map[string]string `json:"system_info"` + PlatformInfo map[string]string `json:"platform_info"` + OSVersion map[string]string `json:"os_version"` + OsqueryInfo map[string]string `json:"osquery_info"` + } `json:"host_details"` + } + if err := json.Unmarshal([]byte(n.RawEnrollment), &outer); err != nil { + // Malformed payload — return the bare node, don't fail the request. 
+ return view + } + enr := &NodeEnrichment{} + if si := outer.HostDetails.SystemInfo; len(si) > 0 { + enr.System = &SystemInfo{ + HardwareVendor: si["hardware_vendor"], + HardwareModel: si["hardware_model"], + HardwareVersion: si["hardware_version"], + HardwareSerial: si["hardware_serial"], + CPUBrand: si["cpu_brand"], + CPUType: si["cpu_type"], + CPUSubtype: si["cpu_subtype"], + CPUPhysicalCores: si["cpu_physical_cores"], + CPULogicalCores: si["cpu_logical_cores"], + PhysicalMemory: si["physical_memory"], + ComputerName: si["computer_name"], + LocalHostname: si["local_hostname"], + } + } + if pi := outer.HostDetails.PlatformInfo; len(pi) > 0 { + enr.BIOS = &BIOSInfo{ + Vendor: pi["vendor"], + Version: pi["version"], + Date: pi["date"], + Revision: pi["revision"], + Address: pi["address"], + Size: pi["size"], + VolumeSize: pi["volume_size"], + } + } + if ov := outer.HostDetails.OSVersion; len(ov) > 0 { + enr.OS = &OSInfo{ + Name: ov["name"], + Version: ov["version"], + Codename: ov["codename"], + Major: ov["major"], + Minor: ov["minor"], + Patch: ov["patch"], + Platform: ov["platform"], + PlatformLike: ov["platform_like"], + } + } + if oi := outer.HostDetails.OsqueryInfo; len(oi) > 0 { + enr.Osquery = &OsqueryRuntime{ + Version: oi["version"], + BuildPlatform: oi["build_platform"], + BuildDistro: oi["build_distro"], + Extensions: oi["extensions"], + StartTime: oi["start_time"], + ConfigValid: oi["config_valid"], + } + } + // Drop the enrichment block entirely when nothing was populated, so that a + // node with empty/whitespace RawEnrollment doesn't leak a "system_info: {}" + // shell that misleads operators into thinking we have data we don't. + if enr.System == nil && enr.BIOS == nil && enr.OS == nil && enr.Osquery == nil { + return view + } + view.Enrichment = enr + return view +} + +// ProjectNodes wraps a slice with ProjectNode — used by the list endpoint to +// keep the table-row payload consistent with the show endpoint. 
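The allowlist-projection idea behind ProjectNode can be shown in miniature. This stdlib-only sketch (hypothetical `project` helper, field set abbreviated to two keys) decodes only the keys it intends to expose, so `enroll_secret` — or any future osquery-side field — can never reach the response:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// systemInfo is a trimmed stand-in for the SPA-facing SystemInfo projection.
type systemInfo struct {
	HardwareVendor string `json:"hardware_vendor,omitempty"`
	ComputerName   string `json:"computer_name,omitempty"`
}

// project decodes only host_details.system_info, then copies the allowlisted
// keys. Unknown keys (host_identifier, enroll_secret, …) are simply never
// read, so they cannot leak.
func project(raw string) (*systemInfo, error) {
	var outer struct {
		HostDetails struct {
			SystemInfo map[string]string `json:"system_info"`
		} `json:"host_details"`
	}
	if err := json.Unmarshal([]byte(raw), &outer); err != nil {
		return nil, err // caller returns the bare node, never fails the request
	}
	si := outer.HostDetails.SystemInfo
	if len(si) == 0 {
		return nil, nil // nil + omitempty drops the key entirely
	}
	return &systemInfo{
		HardwareVendor: si["hardware_vendor"],
		ComputerName:   si["computer_name"],
	}, nil
}

func main() {
	raw := `{"enroll_secret":"s3cr3t","host_details":{"system_info":{"hardware_vendor":"Apple","computer_name":"mac-01","host_identifier":"dupe"}}}`
	view, _ := project(raw)
	out, _ := json.Marshal(view)
	fmt.Println(string(out)) // enroll_secret and host_identifier never appear
}
```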
+func ProjectNodes(in []nodes.OsqueryNode) []NodeView { + out := make([]NodeView, len(in)) + for i, n := range in { + out[i] = ProjectNode(n) + } + return out +} diff --git a/pkg/types/types.go b/pkg/types/types.go index 2536441a..532279e7 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -1,6 +1,10 @@ package types -import "time" +import ( + "time" + + "github.com/jmpsec/osctrl/pkg/queries" +) // OsqueryTable to show tables to query type OsqueryTable struct { @@ -84,6 +88,14 @@ type ApiLoginRequest struct { ExpHours int `json:"exp_hours"` } +// LoginEnvironment is the pre-auth-safe projection of an environment returned +// by GET /api/v1/login/environments. UUID + name only — every other field +// stays behind auth. +type LoginEnvironment struct { + UUID string `json:"uuid"` + Name string `json:"name"` +} + // ApiErrorResponse to be returned to API requests with the error message type ApiErrorResponse struct { Error string `json:"error"` @@ -160,6 +172,274 @@ type ApiUserRequest struct { Environments []string `json:"environments"` } +// NodesPagedResponse is the SPA-canonical paginated response for GET /api/v1/nodes/{env}. +// Items are NodeView — OsqueryNode plus the optional `system_info` enrichment +// block (CPU cores, BIOS, hardware vendor/model) parsed from RawEnrollment. +// The embed keeps every previous OsqueryNode JSON field at the same key, so +// existing consumers (CLI, dashboards) are unaffected. +type NodesPagedResponse struct { + Items []NodeView `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` +} + +// QueriesPagedResponse is the SPA-canonical paginated response for +// GET /api/v1/queries/{env}/list/{target}. 
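The paged response envelopes above all share the same relationship between the fields: `total_pages` is the ceiling of `total_items / page_size`. A hypothetical helper (not shown in this hunk) illustrating the arithmetic:

```go
package main

import "fmt"

// totalPages returns ceil(totalItems / pageSize) using integer arithmetic —
// the invariant the TotalPages field carries in every *PagedResponse.
// Illustrative helper; the handler-side implementation may differ.
func totalPages(totalItems int64, pageSize int) int {
	if pageSize <= 0 || totalItems <= 0 {
		return 0
	}
	return int((totalItems + int64(pageSize) - 1) / int64(pageSize))
}

func main() {
	fmt.Println(totalPages(3, 2))  // 2 — items 1-2 on page 1, item 3 on page 2
	fmt.Println(totalPages(0, 50)) // 0 — an empty result set has no pages
}
```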
+type QueriesPagedResponse struct { + Items []queries.DistributedQuery `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` +} + +// QueryResultsResponse is the SPA-canonical paginated response for +// GET /api/v1/queries/{env}/results/{name}. +type QueryResultsResponse struct { + Items []map[string]any `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` + Since string `json:"since,omitempty"` +} + +// SavedQueryView is the SPA-canonical projection of a saved query. +// We use a hand-typed struct (rather than queries.SavedQuery directly) so the +// JSON envelope stays stable even if the storage struct gains fields. +// Timestamps are emitted as RFC3339 (Go time.Time default JSON encoding), to +// match the OpenAPI schema (date-time) and the SPA's formatRelative parser. +type SavedQueryView struct { + ID uint `json:"id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Name string `json:"name"` + Creator string `json:"creator"` + Query string `json:"query"` + EnvironmentID uint `json:"environment_id"` + ExtraData string `json:"extra_data,omitempty"` +} + +// SavedQueriesPagedResponse is the SPA-canonical paginated response for +// GET /api/v1/saved-queries/{env}. +type SavedQueriesPagedResponse struct { + Items []SavedQueryView `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` +} + +// SavedQueryCreateRequest is the body shape for POST /api/v1/saved-queries/{env}. +type SavedQueryCreateRequest struct { + Name string `json:"name"` + Query string `json:"query"` +} + +// SavedQueryUpdateRequest is the body shape for PATCH /api/v1/saved-queries/{env}/{name}. 
+type SavedQueryUpdateRequest struct { + Query string `json:"query"` +} + +// CarvesPagedResponse is the SPA-canonical paginated response for +// GET /api/v1/carves/{env}. Items are carve-type DistributedQuery rows +// (one per carve operation, regardless of how many nodes the carve targeted). +type CarvesPagedResponse struct { + Items []queries.DistributedQuery `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` +} + +// CarveFileView is the SPA-canonical projection of a single carved file +// row (one per node that completed the carve). Timestamps are RFC3339 so +// the SPA's formatRelative parser handles them; CarveID is the disambiguator +// when downloading the archive of a multi-node carve. +type CarveFileView struct { + CarveID string `json:"carve_id"` + SessionID string `json:"session_id"` + UUID string `json:"uuid"` + Path string `json:"path"` + Status string `json:"status"` + CarveSize int `json:"carve_size"` + BlockSize int `json:"block_size"` + TotalBlocks int `json:"total_blocks"` + CompletedBlocks int `json:"completed_blocks"` + Archived bool `json:"archived"` + CreatedAt time.Time `json:"created_at"` + CompletedAt time.Time `json:"completed_at"` +} + +// CarveDetailResponse is the SPA-canonical response for +// GET /api/v1/carves/{env}/{name}. It pairs the carve QUERY metadata with +// the per-node CarvedFile rows produced by the carve. +type CarveDetailResponse struct { + Query queries.DistributedQuery `json:"query"` + Files []CarveFileView `json:"files"` +} + +// EnvAccessView mirrors users.EnvAccess but lives in the types package so +// the API request/response shapes don't pull in pkg/users for SPA-side codegen. +type EnvAccessView struct { + User bool `json:"user"` + Query bool `json:"query"` + Carve bool `json:"carve"` + Admin bool `json:"admin"` +} + +// SetPermissionsRequest is the body for POST /api/v1/users/{username}/permissions. 
+type SetPermissionsRequest struct { + EnvUUID string `json:"env_uuid"` + Access EnvAccessView `json:"access"` +} + +// TokenResponse is returned by POST /api/v1/users/{username}/token/refresh +// and by login. The Token is shown ONCE to the operator (so they can copy it +// for CLI use); it isn't returned by any GET endpoint after refresh. +type TokenResponse struct { + Token string `json:"token"` + Expires time.Time `json:"expires"` +} + +// UserMeResponse is the SPA-canonical projection of the currently-authenticated +// user. Used by GET /api/v1/users/me. +type UserMeResponse struct { + Username string `json:"username"` + Email string `json:"email"` + Fullname string `json:"fullname"` + Admin bool `json:"admin"` + Service bool `json:"service"` + UUID string `json:"uuid"` + TokenExpire time.Time `json:"token_expire"` + LastAccess time.Time `json:"last_access"` +} + +// UserMePatchRequest is the body for PATCH /api/v1/users/me — operators can +// update their own profile (email and fullname only). +type UserMePatchRequest struct { + Email string `json:"email"` + Fullname string `json:"fullname"` +} + +// PasswordChangeRequest is the body for POST /api/v1/users/me/password. +type PasswordChangeRequest struct { + CurrentPassword string `json:"current_password"` + NewPassword string `json:"new_password"` +} + +// --------------------------------------------------------------------------- +// Environments (Track 8) +// --------------------------------------------------------------------------- + +// EnvCreateRequest is the body for POST /api/v1/environments. +type EnvCreateRequest struct { + Name string `json:"name"` + Hostname string `json:"hostname"` + Type string `json:"type,omitempty"` + Icon string `json:"icon,omitempty"` +} + +// EnvUpdateRequest is the body for PATCH /api/v1/environments/{env}. +// Pointer fields distinguish "unset" from "set to empty"; only supplied +// fields are written. 
+type EnvUpdateRequest struct { + Name *string `json:"name,omitempty"` + Hostname *string `json:"hostname,omitempty"` + Type *string `json:"type,omitempty"` + Icon *string `json:"icon,omitempty"` + DebugHTTP *bool `json:"debug_http,omitempty"` + AcceptEnrolls *bool `json:"accept_enrolls,omitempty"` +} + +// EnvConfigResponse is the GET /api/v1/environments/config/{env} payload — +// each field is the raw JSON string for that osquery config section so the +// SPA's Monaco editor can render and edit it as-is. +type EnvConfigResponse struct { + Options string `json:"options"` + Schedule string `json:"schedule"` + Packs string `json:"packs"` + Decorators string `json:"decorators"` + ATC string `json:"atc"` + Flags string `json:"flags"` +} + +// EnvConfigPatchRequest is the body for PATCH /api/v1/environments/config/{env}. +// Pointer fields: nil means "leave this section alone", non-nil writes it. +// Each non-nil value is JSON-validated before persisting; the handler rejects +// the whole payload if any section is invalid (no partial writes). +type EnvConfigPatchRequest struct { + Options *string `json:"options,omitempty"` + Schedule *string `json:"schedule,omitempty"` + Packs *string `json:"packs,omitempty"` + Decorators *string `json:"decorators,omitempty"` + ATC *string `json:"atc,omitempty"` + Flags *string `json:"flags,omitempty"` +} + +// EnvIntervalsPatchRequest is the body for PATCH /api/v1/environments/intervals/{env}. +// Each interval is in seconds; pointer semantics same as EnvConfigPatchRequest. +type EnvIntervalsPatchRequest struct { + ConfigInterval *int `json:"config_interval,omitempty"` + LogInterval *int `json:"log_interval,omitempty"` + QueryInterval *int `json:"query_interval,omitempty"` +} + +// EnvExpirationPatchRequest is the body for PATCH /api/v1/environments/expiration/{env}. +// Action is one of: extend, expire, rotate, not-expire. 
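The pointer-field PATCH contract (nil = leave the section alone, non-nil = validate then write, reject the whole payload before any write) can be sketched with the standard library; `patch` and `applyPatch` below are illustrative names, not the handler's actual code:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// patch mirrors two sections of EnvConfigPatchRequest; nil means
// "leave this section alone".
type patch struct {
	Options  *string `json:"options,omitempty"`
	Schedule *string `json:"schedule,omitempty"`
}

// applyPatch validates every supplied section before writing any of them,
// so an invalid section can never cause a partial write.
func applyPatch(current map[string]string, p patch) error {
	updates := map[string]*string{"options": p.Options, "schedule": p.Schedule}
	for name, v := range updates { // validation pass — no writes yet
		if v != nil && !json.Valid([]byte(*v)) {
			return errors.New("invalid JSON in section " + name)
		}
	}
	for name, v := range updates { // write pass — all sections known-valid
		if v != nil {
			current[name] = *v
		}
	}
	return nil
}

func main() {
	env := map[string]string{"options": "{}", "schedule": "{}"}
	good := `{"utc":true}`
	bad := "{not json"
	err := applyPatch(env, patch{Options: &good, Schedule: &bad})
	fmt.Println(err != nil, env["options"]) // whole payload rejected; options untouched
}
```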
+type EnvExpirationPatchRequest struct { + Action string `json:"action"` +} + +// --------------------------------------------------------------------------- +// Settings (Track 9) +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Audit log (Track 10) +// --------------------------------------------------------------------------- + +// AuditLogView is the SPA-canonical projection of one pkg/auditlog.AuditLog row. +// We use a hand-typed struct (rather than the storage struct directly) so the +// JSON envelope stays stable as the storage shape evolves. Timestamps are +// RFC3339 to match SavedQueryView / CarveFileView and the SPA's formatRelative +// parser. +type AuditLogView struct { + ID uint `json:"id"` + CreatedAt time.Time `json:"created_at"` + Service string `json:"service"` + Username string `json:"username"` + Line string `json:"line"` + LogType uint `json:"log_type"` + Severity uint `json:"severity"` + SourceIP string `json:"source_ip"` + EnvironmentID uint `json:"environment_id"` + EnvUUID string `json:"env_uuid,omitempty"` +} + +// AuditLogsPagedResponse is the SPA-canonical paginated response for +// GET /api/v1/audit-logs. +type AuditLogsPagedResponse struct { + Items []AuditLogView `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalItems int64 `json:"total_items"` + TotalPages int `json:"total_pages"` +} + +// SettingPatchRequest is the body for PATCH /api/v1/settings/{service}/{name}. +// Exactly one of String / Boolean / Integer must be supplied; the handler +// validates the type matches what's stored. Type is informational and +// optional — when omitted the handler infers from the supplied field. 
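The "exactly one of String / Boolean / Integer" rule the SettingPatchRequest doc comment describes reduces to counting non-nil pointers; `settingPatch` and `validate` below are hypothetical names sketching that check:

```go
package main

import "fmt"

// settingPatch mirrors SettingPatchRequest's pointer value fields.
type settingPatch struct {
	String  *string
	Boolean *bool
	Integer *int64
}

// validate enforces the exactly-one-of rule: zero supplied fields is as
// invalid as two or three. Illustrative, not the handler's actual code.
func validate(p settingPatch) error {
	n := 0
	if p.String != nil {
		n++
	}
	if p.Boolean != nil {
		n++
	}
	if p.Integer != nil {
		n++
	}
	if n != 1 {
		return fmt.Errorf("exactly one value field required, got %d", n)
	}
	return nil
}

func main() {
	b := true
	fmt.Println(validate(settingPatch{Boolean: &b}) == nil) // true — one field supplied
	fmt.Println(validate(settingPatch{}) != nil)            // true — nothing supplied is an error
}
```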
+type SettingPatchRequest struct { + Type string `json:"type,omitempty"` + String *string `json:"string,omitempty"` + Boolean *bool `json:"boolean,omitempty"` + Integer *int64 `json:"integer,omitempty"` +} + // TLSEnvironmentView is the low-privilege projection of an environment. // UserLevel operators (env scope) need basic env metadata so the SPA can // render its env switcher / dashboard / table chrome — but they MUST NOT