diff --git a/CHANGELOG.md b/CHANGELOG.md index 25139794..60fbd72b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented here. ## [Unreleased] +### Changed + +- **Store layer extraction**: All HTTP handlers now use store interfaces instead of embedding `*db.Pool` directly. Store interfaces are defined on the handler side and implemented in `internal/store/`, making handlers testable without a running database and centralizing SQL query knowledge. Affected handlers: auth, teams, analytics, executions, reports, admin, invitations, oauth. + +- **Bulk test result inserts**: Report ingestion now uses `pgx.Batch` to insert test results in bulk instead of one query per result. This eliminates the N+1 insert pattern that caused 1000+ round-trips for large reports. + ### Fixed - **IDOR vulnerability in invitation handlers**: `Create`, `List`, and `Revoke` invitation endpoints (`POST/GET/DELETE /api/v1/teams/{teamID}/invitations`) now verify that the authenticated user's team matches the URL `teamID` before checking role permissions. Previously, any maintainer or owner could list, create, or revoke invitations for any team regardless of membership. diff --git a/README.md b/README.md index efbda937..6fedffaf 100644 --- a/README.md +++ b/README.md @@ -497,7 +497,7 @@ internal/ db/ # Database pool, migrations handler/ # HTTP handlers (reports, executions, teams, admin, etc.) server/ # Router and middleware setup - store/ # Data access (audit, webhooks, quality gates) + store/ # Data access layer (store interfaces + implementations) github/ # GitHub commit status client llm/ # LLM provider abstraction (Anthropic, OpenAI, mock) mail/ # Email sender interface and SMTP implementation diff --git a/internal/handler/admin.go b/internal/handler/admin.go index 42ff6dae..80a82d54 100644 --- a/internal/handler/admin.go +++ b/internal/handler/admin.go @@ -5,15 +5,14 @@ import ( "time" "github.com/google/uuid" - "github.com/scaledtest/scaledtest/internal/db" - "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" ) // AdminHandler handles admin-only endpoints. type AdminHandler struct { AuditStore *store.AuditStore - DB *db.Pool + AdminStore adminStore } // ListAuditLog handles GET /api/v1/admin/audit-log. @@ -66,46 +65,18 @@ func (h *AdminHandler) ListAuditLog(w http.ResponseWriter, r *http.Request) { // ListUsers handles GET /api/v1/admin/users. func (h *AdminHandler) ListUsers(w http.ResponseWriter, r *http.Request) { - if h.DB == nil { + if h.AdminStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } limit, offset := parsePagination(r) - rows, err := h.DB.Query(r.Context(), - `SELECT id, email, display_name, role, created_at, updated_at - FROM users - ORDER BY created_at DESC - LIMIT $1 OFFSET $2`, - limit, offset) + users, total, err := h.AdminStore.ListUsers(r.Context(), limit, offset) if err != nil { Error(w, http.StatusInternalServerError, "failed to query users") return } - defer rows.Close() - - users := []model.User{} - for rows.Next() { - var u model.User - if err := rows.Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role, &u.CreatedAt, &u.UpdatedAt); err != nil { - Error(w, http.StatusInternalServerError, "failed to scan user") - return - } - users = append(users, u) - } - if err := rows.Err(); err != nil { - Error(w, http.StatusInternalServerError, "failed to iterate users") - return - } - - var total int - err = h.DB.QueryRow(r.Context(), `SELECT COUNT(*) FROM users`).Scan(&total) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to count users") - return - } - JSON(w, http.StatusOK, map[string]interface{}{ "users": users, "total": total, diff --git a/internal/handler/analytics.go b/internal/handler/analytics.go index 9f26e13d..65d9e242 100644 --- a/internal/handler/analytics.go +++ b/internal/handler/analytics.go @@ -9,12 +9,11 @@ import ( "github.com/scaledtest/scaledtest/internal/analytics" "github.com/scaledtest/scaledtest/internal/auth" - "github.com/scaledtest/scaledtest/internal/db" ) // AnalyticsHandler handles analytics endpoints. type AnalyticsHandler struct { - DB *db.Pool + AnalyticsStore analyticsStore } // Trends handles GET /api/v1/analytics/trends. @@ -26,57 +25,30 @@ func (h *AnalyticsHandler) Trends(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.AnalyticsStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } q := parseTrendQuery(r, claims.TeamID) - query := ` - SELECT - time_bucket($1::interval, created_at) AS bucket, - count(*) AS total, - count(*) FILTER (WHERE status = 'passed') AS passed, - count(*) FILTER (WHERE status = 'failed') AS failed, - count(*) FILTER (WHERE status = 'skipped') AS skipped - FROM test_results - WHERE team_id = $2 - AND created_at >= $3 - AND created_at <= $4 - GROUP BY bucket - ORDER BY bucket - ` - - rows, err := h.DB.Query(r.Context(), query, q.GroupBy, q.TeamID, q.StartDate, q.EndDate) + rows, err := h.AnalyticsStore.QueryTrends(r.Context(), q.GroupBy, q.TeamID, q.StartDate, q.EndDate) if err != nil { log.Error().Err(err).Msg("analytics: trends query failed") Error(w, http.StatusInternalServerError, "query failed") return } - defer rows.Close() - - var trends []analytics.TrendPoint - for rows.Next() { - var tp analytics.TrendPoint - if err := rows.Scan(&tp.Date, &tp.Total, &tp.Passed, &tp.Failed, &tp.Skipped); err != nil { - log.Error().Err(err).Msg("analytics: trends scan failed") - Error(w, http.StatusInternalServerError, "query failed") - return + trends := make([]analytics.TrendPoint, len(rows)) + for i, row := range rows { + trends[i] = analytics.TrendPoint{ + Date: row.Date, + Total: row.Total, + Passed: row.Passed, + Failed: row.Failed, + Skipped: row.Skipped, + PassRate: row.PassRate, } - tp.PassRate = analytics.ComputePassRate(tp.Passed, tp.Total) - trends = append(trends, tp) } - if err := rows.Err(); err != nil { - log.Error().Err(err).Msg("analytics: trends iteration failed") - Error(w, http.StatusInternalServerError, "query failed") - return - } - - if trends == nil { - trends = []analytics.TrendPoint{} - } - JSON(w, http.StatusOK, map[string]interface{}{ "trends": trends, }) @@ -91,7 +63,7 @@ func (h *AnalyticsHandler) FlakyTests(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.AnalyticsStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -99,56 +71,30 @@ func (h *AnalyticsHandler) FlakyTests(w http.ResponseWriter, r *http.Request) { q := parseFlakyQuery(r, claims.TeamID) cutoff := time.Now().Add(-q.Window) - // Query: for each test, get ordered statuses, then compute flakiness in Go. - // This avoids complex SQL window functions and uses the analytics.DetectFlaky helper. - query := ` - SELECT - name, - COALESCE(suite, '') AS suite, - COALESCE(file_path, '') AS file_path, - array_agg(status ORDER BY created_at) AS statuses, - (array_agg(status ORDER BY created_at DESC))[1] AS last_status, - count(*) AS total_runs - FROM test_results - WHERE team_id = $1 - AND created_at >= $2 - GROUP BY name, suite, file_path - HAVING count(*) >= $3 - ORDER BY name - ` - - rows, err := h.DB.Query(r.Context(), query, q.TeamID, cutoff, q.MinRuns) + rows, err := h.AnalyticsStore.QueryFlakyTests(r.Context(), claims.TeamID, cutoff, q.MinRuns) if err != nil { log.Error().Err(err).Msg("analytics: flaky query failed") Error(w, http.StatusInternalServerError, "query failed") return } - defer rows.Close() var flaky []analytics.FlakyTest - for rows.Next() { - var ft analytics.FlakyTest - var statuses []string - if err := rows.Scan(&ft.Name, &ft.Suite, &ft.FilePath, &statuses, &ft.LastStatus, &ft.TotalRuns); err != nil { - log.Error().Err(err).Msg("analytics: flaky scan failed") - Error(w, http.StatusInternalServerError, "query failed") - return + for _, fr := range rows { + ft := analytics.FlakyTest{ + Name: fr.Name, + Suite: fr.Suite, + FilePath: fr.FilePath, + LastStatus: fr.LastStatus, + TotalRuns: fr.TotalRuns, } - - ft.FlipCount, ft.FlipRate = analytics.DetectFlaky(statuses) + ft.FlipCount, ft.FlipRate = analytics.DetectFlaky(fr.Statuses) if ft.FlipCount > 0 { flaky = append(flaky, ft) } - if q.Limit > 0 && len(flaky) >= q.Limit { break } } - if err := rows.Err(); err != nil { - log.Error().Err(err).Msg("analytics: flaky iteration failed") - Error(w, http.StatusInternalServerError, "query failed") - return - } if flaky == nil { flaky = []analytics.FlakyTest{} @@ -168,7 +114,7 @@ func (h *AnalyticsHandler) ErrorAnalysis(w http.ResponseWriter, r *http.Request) return } - if h.DB == nil { + if h.AnalyticsStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -176,55 +122,26 @@ func (h *AnalyticsHandler) ErrorAnalysis(w http.ResponseWriter, r *http.Request) start, end := parseDateRange(r) limit := parseIntParam(r, "limit", 20) - query := ` - SELECT - message, - count(*) AS count, - array_agg(DISTINCT name) AS test_names, - min(created_at) AS first_seen, - max(created_at) AS last_seen - FROM test_results - WHERE team_id = $1 - AND status = 'failed' - AND message IS NOT NULL - AND message != '' - AND created_at >= $2 - AND created_at <= $3 - GROUP BY message - ORDER BY count DESC - LIMIT $4 - ` - - rows, err := h.DB.Query(r.Context(), query, claims.TeamID, start, end, limit) + clusters, err := h.AnalyticsStore.QueryErrorClusters(r.Context(), claims.TeamID, start, end, limit) if err != nil { log.Error().Err(err).Msg("analytics: error analysis query failed") Error(w, http.StatusInternalServerError, "query failed") return } - defer rows.Close() - - var clusters []analytics.ErrorCluster - for rows.Next() { - var ec analytics.ErrorCluster - if err := rows.Scan(&ec.Message, &ec.Count, &ec.TestNames, &ec.FirstSeen, &ec.LastSeen); err != nil { - log.Error().Err(err).Msg("analytics: error analysis scan failed") - Error(w, http.StatusInternalServerError, "query failed") - return - } - clusters = append(clusters, ec) - } - if err := rows.Err(); err != nil { - log.Error().Err(err).Msg("analytics: error analysis iteration failed") - Error(w, http.StatusInternalServerError, "query failed") - return - } - if clusters == nil { - clusters = []analytics.ErrorCluster{} + result := make([]analytics.ErrorCluster, len(clusters)) + for i, ec := range clusters { + result[i] = analytics.ErrorCluster{ + Message: ec.Message, + Count: ec.Count, + TestNames: ec.TestNames, + FirstSeen: ec.FirstSeen, + LastSeen: ec.LastSeen, + } } JSON(w, http.StatusOK, map[string]interface{}{ - "errors": clusters, + "errors": result, }) } @@ -237,46 +154,24 @@ func (h *AnalyticsHandler) DurationDistribution(w http.ResponseWriter, r *http.R return } - if h.DB == nil { + if h.AnalyticsStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } start, end := parseDateRange(r) - // Build histogram buckets buckets := analytics.DefaultDurationBuckets() - bucketQuery := ` - SELECT duration_ms - FROM test_results - WHERE team_id = $1 - AND created_at >= $2 - AND created_at <= $3 - ` - - rows, err := h.DB.Query(r.Context(), bucketQuery, claims.TeamID, start, end) + durations, err := h.AnalyticsStore.QueryDurationBuckets(r.Context(), claims.TeamID, start, end) if err != nil { log.Error().Err(err).Msg("analytics: duration bucket query failed") Error(w, http.StatusInternalServerError, "query failed") return } - defer rows.Close() - - for rows.Next() { - var ms int64 - if err := rows.Scan(&ms); err != nil { - log.Error().Err(err).Msg("analytics: duration bucket scan failed") - Error(w, http.StatusInternalServerError, "query failed") - return - } + for _, ms := range durations { idx := analytics.BucketDuration(ms, buckets) buckets[idx].Count++ } - if err := rows.Err(); err != nil { - log.Error().Err(err).Msg("analytics: duration bucket iteration failed") - Error(w, http.StatusInternalServerError, "query failed") - return - } JSON(w, http.StatusOK, map[string]interface{}{ "distribution": buckets, diff --git a/internal/handler/analytics_test.go b/internal/handler/analytics_test.go index 6baaf6ba..faf57de0 100644 --- a/internal/handler/analytics_test.go +++ b/internal/handler/analytics_test.go @@ -7,9 +7,8 @@ import ( "time" ) - func TestAnalyticsTrendsNoDB(t *testing.T) { - h := &AnalyticsHandler{DB: nil} + h := &AnalyticsHandler{AnalyticsStore: nil} req := httptest.NewRequest("GET", "/api/v1/analytics/trends", nil) req = testWithClaimsTeamOnly(req, "team-1") @@ -23,7 +22,7 @@ func TestAnalyticsTrendsNoDB(t *testing.T) { } func TestAnalyticsFlakyTestsNoDB(t *testing.T) { - h := &AnalyticsHandler{DB: nil} + h := &AnalyticsHandler{AnalyticsStore: nil} req := httptest.NewRequest("GET", "/api/v1/analytics/flaky-tests", nil) req = testWithClaimsTeamOnly(req, "team-1") @@ -37,7 +36,7 @@ func TestAnalyticsFlakyTestsNoDB(t *testing.T) { } func TestAnalyticsErrorAnalysisNoDB(t *testing.T) { - h := &AnalyticsHandler{DB: nil} + h := &AnalyticsHandler{AnalyticsStore: nil} req := httptest.NewRequest("GET", "/api/v1/analytics/error-analysis", nil) req = testWithClaimsTeamOnly(req, "team-1") @@ -51,7 +50,7 @@ func TestAnalyticsErrorAnalysisNoDB(t *testing.T) { } func TestAnalyticsDurationDistributionNoDB(t *testing.T) { - h := &AnalyticsHandler{DB: nil} + h := &AnalyticsHandler{AnalyticsStore: nil} req := httptest.NewRequest("GET", "/api/v1/analytics/duration-distribution", nil) req = testWithClaimsTeamOnly(req, "team-1") @@ -65,7 +64,7 @@ func TestAnalyticsDurationDistributionNoDB(t *testing.T) { } func TestAnalyticsUnauthorized(t *testing.T) { - h := &AnalyticsHandler{DB: nil} + h := &AnalyticsHandler{AnalyticsStore: nil} handlers := []struct { name string diff --git a/internal/handler/auth.go b/internal/handler/auth.go index afd0048c..c697cde1 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -15,17 +15,10 @@ import ( const refreshTokenCookie = "refresh_token" -// authDB is the minimal database interface used by AuthHandler. -// *pgxpool.Pool satisfies this interface. -type authDB interface { - QueryRow(ctx context.Context, sql string, args ...any) pgx.Row - Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) -} - // AuthHandler handles authentication endpoints. type AuthHandler struct { - JWT *auth.JWTManager - DB authDB + JWT *auth.JWTManager + AuthStore authStore } // RegisterRequest is the request body for user registration. @@ -62,11 +55,6 @@ type ChangePasswordRequest struct { NewPassword string `json:"new_password" validate:"required,min=8,max=72"` } -// UpdateProfileRequest is the request body for updating the authenticated user's profile. -type UpdateProfileRequest struct { - DisplayName string `json:"display_name" validate:"required,min=1"` -} - // ChangePassword handles POST /api/v1/auth/change-password. func (h *AuthHandler) ChangePassword(w http.ResponseWriter, r *http.Request) { claims := auth.GetClaims(r.Context()) @@ -81,17 +69,12 @@ func (h *AuthHandler) ChangePassword(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - // Look up current password hash - var passwordHash string - err := h.DB.QueryRow(r.Context(), - "SELECT password_hash FROM users WHERE id = $1", - claims.UserID, - ).Scan(&passwordHash) + user, err := h.AuthStore.GetUserByID(r.Context(), claims.UserID) if err == pgx.ErrNoRows { Error(w, http.StatusUnauthorized, "unauthorized") return @@ -101,29 +84,23 @@ func (h *AuthHandler) ChangePassword(w http.ResponseWriter, r *http.Request) { return } - // Verify current password - if !auth.CheckPassword(req.CurrentPassword, passwordHash) { + if !auth.CheckPassword(req.CurrentPassword, user.PasswordHash) { Error(w, http.StatusUnauthorized, "invalid current password") return } - // Hash new password newHash, err := auth.HashPassword(req.NewPassword) if err != nil { Error(w, http.StatusInternalServerError, "internal error") return } - // Update password - tag, err := h.DB.Exec(r.Context(), - "UPDATE users SET password_hash = $1 WHERE id = $2", - newHash, claims.UserID, - ) + rowsAffected, err := h.AuthStore.UpdatePassword(r.Context(), claims.UserID, newHash) if err != nil { Error(w, http.StatusInternalServerError, "internal error") return } - if tag.RowsAffected() != 1 { + if rowsAffected != 1 { Error(w, http.StatusInternalServerError, "internal error") return } @@ -150,18 +127,12 @@ func (h *AuthHandler) UpdateMe(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - var userID, email, displayName, role string - err := h.DB.QueryRow(r.Context(), - `UPDATE users SET display_name = $1, updated_at = now() - WHERE id = $2 - RETURNING id, email, display_name, role`, - req.DisplayName, claims.UserID, - ).Scan(&userID, &email, &displayName, &role) + user, err := h.AuthStore.UpdateProfile(r.Context(), claims.UserID, req.DisplayName) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "user not found") return @@ -172,10 +143,10 @@ func (h *AuthHandler) UpdateMe(w http.ResponseWriter, r *http.Request) { } JSON(w, http.StatusOK, UserResponse{ - ID: userID, - Email: email, - DisplayName: displayName, - Role: role, + ID: user.ID, + Email: user.Email, + DisplayName: user.DisplayName, + Role: user.Role, }) } @@ -187,16 +158,12 @@ func (h *AuthHandler) GetMe(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - var email, displayName, role string - err := h.DB.QueryRow(r.Context(), - "SELECT email, display_name, role FROM users WHERE id = $1", - claims.UserID, - ).Scan(&email, &displayName, &role) + user, err := h.AuthStore.GetUserByID(r.Context(), claims.UserID) if err == pgx.ErrNoRows { Error(w, http.StatusUnauthorized, "unauthorized") return @@ -207,16 +174,16 @@ func (h *AuthHandler) GetMe(w http.ResponseWriter, r *http.Request) { } JSON(w, http.StatusOK, UserResponse{ - ID: claims.UserID, - Email: email, - DisplayName: displayName, - Role: role, + ID: user.ID, + Email: user.Email, + DisplayName: user.DisplayName, + Role: user.Role, }) } // Register handles POST /auth/register. func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -227,10 +194,7 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { return } - // Check if email already exists - var exists bool - err := h.DB.QueryRow(r.Context(), - "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", req.Email).Scan(&exists) + exists, err := h.AuthStore.EmailExists(r.Context(), req.Email) if err != nil { Error(w, http.StatusInternalServerError, "internal error") return @@ -240,7 +204,6 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { return } - // Hash password hash, err := auth.HashPassword(req.Password) if err != nil { Error(w, http.StatusInternalServerError, "internal error") @@ -255,32 +218,22 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { // second INSERT will violate the index (SQLSTATE 23505). In that case we // retry explicitly as 'maintainer', which is correct because a committed // owner row now exists. - var userID, role string - err = h.DB.QueryRow(r.Context(), - `INSERT INTO users (email, password_hash, display_name, role) - SELECT $1, $2, $3, - CASE WHEN NOT EXISTS (SELECT 1 FROM users) THEN 'owner'::text ELSE 'maintainer'::text END - RETURNING id, role`, - req.Email, hash, req.DisplayName, - ).Scan(&userID, &role) + userID, role, err := h.AuthStore.CreateUser(r.Context(), req.Email, hash, req.DisplayName, "") if err != nil { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == "23505" && pgErr.ConstraintName == "idx_users_single_owner" { - // A concurrent registration claimed the owner role; retry as maintainer. - err = h.DB.QueryRow(r.Context(), - `INSERT INTO users (email, password_hash, display_name, role) - VALUES ($1, $2, $3, 'maintainer') - RETURNING id, role`, - req.Email, hash, req.DisplayName, - ).Scan(&userID, &role) - } - if err != nil { + userID, err = h.AuthStore.CreateUserWithRole(r.Context(), req.Email, hash, req.DisplayName, "maintainer") + if err != nil { + Error(w, http.StatusInternalServerError, "internal error") + return + } + role = "maintainer" + } else { Error(w, http.StatusInternalServerError, "internal error") return } } - // Generate token pair and create session resp, err := h.issueTokens(r.Context(), w, r, userID, req.Email, role, "") if err != nil { Error(w, http.StatusInternalServerError, "internal error") @@ -299,7 +252,7 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { // Login handles POST /auth/login. func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -310,12 +263,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { return } - // Look up user by email - var userID, passwordHash, displayName, role string - err := h.DB.QueryRow(r.Context(), - `SELECT id, password_hash, display_name, role FROM users WHERE email = $1`, - req.Email, - ).Scan(&userID, &passwordHash, &displayName, &role) + user, err := h.AuthStore.GetUserByEmail(r.Context(), req.Email) if err == pgx.ErrNoRows { Error(w, http.StatusUnauthorized, "invalid credentials") return @@ -325,8 +273,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { return } - // Verify password - if !auth.CheckPassword(req.Password, passwordHash) { + if !auth.CheckPassword(req.Password, user.PasswordHash) { Error(w, http.StatusUnauthorized, "invalid credentials") return } @@ -335,23 +282,19 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { // have team context for team-scoped API calls. Best-effort: if the // lookup fails for any reason, teamID stays empty. var teamID string - _ = h.DB.QueryRow(r.Context(), - `SELECT team_id FROM user_teams WHERE user_id = $1 ORDER BY joined_at ASC LIMIT 1`, - userID, - ).Scan(&teamID) + teamID, _ = h.AuthStore.GetPrimaryTeamID(r.Context(), user.ID) - // Generate token pair and create session - resp, err := h.issueTokens(r.Context(), w, r, userID, req.Email, role, teamID) + resp, err := h.issueTokens(r.Context(), w, r, user.ID, user.Email, user.Role, teamID) if err != nil { Error(w, http.StatusInternalServerError, "internal error") return } resp.User = UserResponse{ - ID: userID, - Email: req.Email, - DisplayName: displayName, - Role: role, + ID: user.ID, + Email: user.Email, + DisplayName: user.DisplayName, + Role: user.Role, } JSON(w, http.StatusOK, resp) @@ -359,7 +302,7 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { // Refresh handles POST /auth/refresh. func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) { - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -370,15 +313,7 @@ func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) { return } - // Look up session by refresh token - var sessionID, userID string - var expiresAt time.Time - err = h.DB.QueryRow(r.Context(), - `SELECT s.id, s.user_id, s.expires_at - FROM sessions s - WHERE s.refresh_token = $1`, - cookie.Value, - ).Scan(&sessionID, &userID, &expiresAt) + session, err := h.AuthStore.GetSessionByRefreshToken(r.Context(), cookie.Value) if err == pgx.ErrNoRows { Error(w, http.StatusUnauthorized, "invalid refresh token") return @@ -388,39 +323,32 @@ func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) { return } - if time.Now().After(expiresAt) { - // Delete expired session - _, _ = h.DB.Exec(r.Context(), "DELETE FROM sessions WHERE id = $1", sessionID) + if time.Now().After(session.ExpiresAt) { + _ = h.AuthStore.DeleteSession(r.Context(), session.ID) clearRefreshCookie(w, r) Error(w, http.StatusUnauthorized, "refresh token expired") return } - // Look up user - var email, displayName, role string - err = h.DB.QueryRow(r.Context(), - `SELECT email, display_name, role FROM users WHERE id = $1`, userID, - ).Scan(&email, &displayName, &role) + user, err := h.AuthStore.GetUserByID(r.Context(), session.UserID) if err != nil { Error(w, http.StatusInternalServerError, "internal error") return } - // Delete old session (rotate refresh token) - _, _ = h.DB.Exec(r.Context(), "DELETE FROM sessions WHERE id = $1", sessionID) + _ = h.AuthStore.DeleteSession(r.Context(), session.ID) - // Issue new token pair - resp, err := h.issueTokens(r.Context(), w, r, userID, email, role, "") + resp, err := h.issueTokens(r.Context(), w, r, user.ID, user.Email, user.Role, "") if err != nil { Error(w, http.StatusInternalServerError, "internal error") return } resp.User = UserResponse{ - ID: userID, - Email: email, - DisplayName: displayName, - Role: role, + ID: user.ID, + Email: user.Email, + DisplayName: user.DisplayName, + Role: user.Role, } JSON(w, http.StatusOK, resp) @@ -428,20 +356,18 @@ func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) { // Logout handles POST /auth/logout. func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) { - if h.DB == nil { + if h.AuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } cookie, err := r.Cookie(refreshTokenCookie) if err != nil { - // No cookie — already logged out, just return success JSON(w, http.StatusOK, map[string]string{"message": "logged out"}) return } - // Delete session by refresh token - _, _ = h.DB.Exec(r.Context(), "DELETE FROM sessions WHERE refresh_token = $1", cookie.Value) + _ = h.AuthStore.DeleteSessionByRefreshToken(r.Context(), cookie.Value) clearRefreshCookie(w, r) JSON(w, http.StatusOK, map[string]string{"message": "logged out"}) @@ -454,10 +380,8 @@ func (h *AuthHandler) issueTokens(ctx context.Context, w http.ResponseWriter, r return nil, err } - // Extract client metadata userAgent := r.UserAgent() ipAddr := net.ParseIP(r.RemoteAddr) - // RemoteAddr may include port — try to parse host only if ipAddr == nil { host, _, _ := net.SplitHostPort(r.RemoteAddr) ipAddr = net.ParseIP(host) @@ -465,12 +389,7 @@ func (h *AuthHandler) issueTokens(ctx context.Context, w http.ResponseWriter, r expiresAt := time.Now().Add(h.JWT.RefreshDuration()) - _, err = h.DB.Exec(ctx, - `INSERT INTO sessions (user_id, refresh_token, user_agent, ip_address, expires_at) - VALUES ($1, $2, $3, $4, $5)`, - userID, pair.RefreshToken, userAgent, ipAddr, expiresAt, - ) - if err != nil { + if err := h.AuthStore.CreateSession(ctx, userID, pair.RefreshToken, userAgent, ipAddr, expiresAt); err != nil { return nil, err } diff --git a/internal/handler/auth_integration_test.go b/internal/handler/auth_integration_test.go index 846fbdb7..470fc2e8 100644 --- a/internal/handler/auth_integration_test.go +++ b/internal/handler/auth_integration_test.go @@ -15,12 +15,13 @@ import ( "github.com/scaledtest/scaledtest/internal/auth" "github.com/scaledtest/scaledtest/internal/integration" + "github.com/scaledtest/scaledtest/internal/store" ) // newIntegrationAuthHandler returns an AuthHandler backed by the given real DB pool. func newIntegrationAuthHandler(tdb *integration.TestDB) *AuthHandler { jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - return &AuthHandler{JWT: jwt, DB: tdb.Pool} + return &AuthHandler{JWT: jwt, AuthStore: store.NewAuthStore(tdb.Pool)} } // registerViaHandler calls the Register handler with the given email and returns diff --git a/internal/handler/auth_test.go b/internal/handler/auth_test.go index b8d727f0..29d2c9c8 100644 --- a/internal/handler/auth_test.go +++ b/internal/handler/auth_test.go @@ -3,6 +3,7 @@ package handler import ( "context" "encoding/json" + "net" "net/http" "net/http/httptest" "strings" @@ -13,36 +14,68 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/scaledtest/scaledtest/internal/auth" + "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" ) -// mockAuthDB implements authDB for testing. -type mockAuthDB struct { - queryRowFn func(ctx context.Context, sql string, args ...any) pgx.Row - execFn func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) -} +const testSecret = "test-secret-32-chars-long-enough!" -func (m *mockAuthDB) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { - return m.queryRowFn(ctx, sql, args...) +func newTestAuthHandler() *AuthHandler { + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + return &AuthHandler{JWT: jwt, AuthStore: nil} } -func (m *mockAuthDB) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return m.execFn(ctx, sql, arguments...) +// mockAuthStore implements authStore for testing. +type mockAuthStore struct { + getUserByEmailFn func(ctx context.Context, email string) (*model.User, error) + getUserByIDFn func(ctx context.Context, id string) (*model.User, error) + emailExistsFn func(ctx context.Context, email string) (bool, error) + createUserFn func(ctx context.Context, email, passwordHash, displayName, role string) (string, string, error) + createUserWithRoleFn func(ctx context.Context, email, passwordHash, displayName, role string) (string, error) + updatePasswordFn func(ctx context.Context, userID, passwordHash string) (int64, error) + updateProfileFn func(ctx context.Context, userID, displayName string) (*model.User, error) + getPrimaryTeamIDFn func(ctx context.Context, userID string) (string, error) + createSessionFn func(ctx context.Context, userID, refreshToken, userAgent string, ipAddr net.IP, expiresAt time.Time) error + getSessionByRefreshTokenFn func(ctx context.Context, refreshToken string) (*store.SessionInfo, error) + deleteSessionFn func(ctx context.Context, sessionID string) error + deleteSessionByRefreshFn func(ctx context.Context, refreshToken string) error } -// mockRow implements pgx.Row for testing. -type mockRow struct { - scanFn func(dest ...any) error +func (m *mockAuthStore) GetUserByEmail(ctx context.Context, email string) (*model.User, error) { + return m.getUserByEmailFn(ctx, email) } - -func (r *mockRow) Scan(dest ...any) error { - return r.scanFn(dest...) +func (m *mockAuthStore) GetUserByID(ctx context.Context, id string) (*model.User, error) { + return m.getUserByIDFn(ctx, id) } - -const testSecret = "test-secret-32-chars-long-enough!" - -func newTestAuthHandler() *AuthHandler { - jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - return &AuthHandler{JWT: jwt, DB: nil} +func (m *mockAuthStore) EmailExists(ctx context.Context, email string) (bool, error) { + return m.emailExistsFn(ctx, email) +} +func (m *mockAuthStore) CreateUser(ctx context.Context, email, passwordHash, displayName, role string) (string, string, error) { + return m.createUserFn(ctx, email, passwordHash, displayName, role) +} +func (m *mockAuthStore) CreateUserWithRole(ctx context.Context, email, passwordHash, displayName, role string) (string, error) { + return m.createUserWithRoleFn(ctx, email, passwordHash, displayName, role) +} +func (m *mockAuthStore) UpdatePassword(ctx context.Context, userID, passwordHash string) (int64, error) { + return m.updatePasswordFn(ctx, userID, passwordHash) +} +func (m *mockAuthStore) UpdateProfile(ctx context.Context, userID, displayName string) (*model.User, error) { + return m.updateProfileFn(ctx, userID, displayName) +} +func (m *mockAuthStore) GetPrimaryTeamID(ctx context.Context, userID string) (string, error) { + return m.getPrimaryTeamIDFn(ctx, userID) +} +func (m *mockAuthStore) CreateSession(ctx context.Context, userID, refreshToken, userAgent string, ipAddr net.IP, expiresAt time.Time) error { + return m.createSessionFn(ctx, userID, refreshToken, userAgent, ipAddr, expiresAt) +} +func (m *mockAuthStore) GetSessionByRefreshToken(ctx context.Context, refreshToken string) (*store.SessionInfo, error) { + return m.getSessionByRefreshTokenFn(ctx, refreshToken) +} +func (m *mockAuthStore) DeleteSession(ctx context.Context, sessionID string) error { + return m.deleteSessionFn(ctx, sessionID) +} +func (m *mockAuthStore) DeleteSessionByRefreshToken(ctx context.Context, refreshToken string) error { + return m.deleteSessionByRefreshFn(ctx, refreshToken) } func TestRegisterNoDB(t *testing.T) { @@ -124,8 +157,6 @@ func TestRegisterInvalidRequest(t *testing.T) { h.Register(w, req) - // Without DB, we get 503 (DB check happens first). With DB, bad input gets 400. - // Either way, it should NOT be 200/201. if w.Code == http.StatusCreated || w.Code == http.StatusOK { t.Errorf("Register(%s): should not succeed, got %d", tt.name, w.Code) } @@ -161,9 +192,6 @@ func TestLoginInvalidRequest(t *testing.T) { } func TestRefreshMissingCookie(t *testing.T) { - // Need a handler with a real (mock) DB for this test to get past the nil check. - // With nil DB, we get 503, which is correct but tests a different path. - // This test verifies the no-DB path returns 503. h := newTestAuthHandler() req := httptest.NewRequest("POST", "/auth/refresh", nil) @@ -177,7 +205,6 @@ func TestRefreshMissingCookie(t *testing.T) { } func TestLogoutNoCookie(t *testing.T) { - // Without DB, returns 503 h := newTestAuthHandler() req := httptest.NewRequest("POST", "/auth/logout", nil) @@ -191,7 +218,6 @@ func TestLogoutNoCookie(t *testing.T) { } func TestAuthResponseShape(t *testing.T) { - // Verify the JSON shape of AuthResponse resp := AuthResponse{ User: UserResponse{ ID: "user-1", @@ -213,14 +239,12 @@ func TestAuthResponseShape(t *testing.T) { t.Fatalf("unmarshal: %v", err) } - // Check top-level keys for _, key := range []string{"user", "access_token", "expires_at"} { if _, ok := decoded[key]; !ok { t.Errorf("missing key %q in AuthResponse", key) } } - // Check user keys user := decoded["user"].(map[string]interface{}) for _, key := range []string{"id", "email", "display_name", "role"} { if _, ok := user[key]; !ok { @@ -231,7 +255,6 @@ func TestAuthResponseShape(t *testing.T) { func TestRefreshCookieAttributes(t *testing.T) { w := httptest.NewRecorder() - // Simulate HTTPS via X-Forwarded-Proto so Secure flag is set req := httptest.NewRequest("POST", "/auth/refresh", nil) req.Header.Set("X-Forwarded-Proto", "https") setRefreshCookie(w, req, "test-token", 7*24*time.Hour) @@ -264,7 +287,6 @@ func TestRefreshCookieAttributes(t *testing.T) { func TestRefreshCookieNotSecureOverHTTP(t *testing.T) { w := httptest.NewRecorder() - // Plain HTTP request — Secure should be false req := httptest.NewRequest("POST", "/auth/refresh", nil) setRefreshCookie(w, req, "test-token", 7*24*time.Hour) @@ -302,7 +324,6 @@ func TestChangePasswordNoAuth(t *testing.T) { body := `{"current_password":"oldpassword123","new_password":"newpassword123"}` req := httptest.NewRequest("POST", "/api/v1/auth/change-password", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") - // No claims injected into context w := httptest.NewRecorder() h.ChangePassword(w, req) @@ -334,20 +355,20 @@ func TestChangePasswordSuccess(t *testing.T) { t.Fatalf("HashPassword: %v", err) } - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = hash - return nil - }} + ms := &mockAuthStore{ + getUserByIDFn: func(_ context.Context, id string) (*model.User, error) { + return &model.User{ID: id, PasswordHash: hash}, nil + }, + updatePasswordFn: func(_ context.Context, _, _ string) (int64, error) { + return 1, nil }, - execFn: func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("UPDATE 1"), nil + createSessionFn: func(_ context.Context, _ string, _ string, _ string, _ net.IP, _ time.Time) error { + return nil }, } jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwt, DB: mockDB} + h := &AuthHandler{JWT: jwt, AuthStore: ms} body := `{"current_password":"oldpassword123","new_password":"newpassword123"}` req := httptest.NewRequest("POST", "/api/v1/auth/change-password", strings.NewReader(body)) @@ -368,17 +389,14 @@ func TestChangePasswordWrongCurrentPassword(t *testing.T) { t.Fatalf("HashPassword: %v", err) } - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = hash - return nil - }} + ms := &mockAuthStore{ + getUserByIDFn: func(_ context.Context, id string) (*model.User, error) { + return &model.User{ID: id, PasswordHash: hash}, nil }, } jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwt, DB: mockDB} + h := &AuthHandler{JWT: jwt, AuthStore: ms} body := `{"current_password":"wrongpassword","new_password":"newpassword123"}` req := httptest.NewRequest("POST", "/api/v1/auth/change-password", strings.NewReader(body)) @@ -399,20 +417,17 @@ func TestChangePasswordRowsAffectedZero(t *testing.T) { t.Fatalf("HashPassword: %v", err) } - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = hash - return nil - }} + ms := &mockAuthStore{ + getUserByIDFn: func(_ context.Context, id string) (*model.User, error) { + return &model.User{ID: id, PasswordHash: hash}, nil }, - execFn: func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("UPDATE 0"), nil + updatePasswordFn: func(_ context.Context, _, _ string) (int64, error) { + return 0, nil }, } jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwt, DB: mockDB} + h := &AuthHandler{JWT: jwt, AuthStore: ms} body := `{"current_password":"oldpassword123","new_password":"newpassword123"}` req := httptest.NewRequest("POST", "/api/v1/auth/change-password", strings.NewReader(body)) @@ -428,8 +443,6 @@ func TestChangePasswordRowsAffectedZero(t *testing.T) { } func TestChangePasswordInvalidRequest(t *testing.T) { - // Decode/validate runs before the DB nil check, so invalid requests must - // produce 400 Bad Request even when h.DB is nil. h := newTestAuthHandler() tests := []struct { @@ -494,39 +507,27 @@ func TestLoginEmbedsPrimaryTeamInJWT(t *testing.T) { } const teamID = "550e8400-e29b-41d4-a716-446655440000" - callCount := 0 - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - callCount++ - switch callCount { - case 1: - // User lookup: id, password_hash, display_name, role - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = "user-id-1" - *(dest[1].(*string)) = hash - *(dest[2].(*string)) = "Test User" - *(dest[3].(*string)) = "maintainer" - return nil - }} - case 2: - // Primary team lookup - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = teamID - return nil - }} - default: - t.Errorf("unexpected QueryRow call #%d", callCount) - return &mockRow{scanFn: func(dest ...any) error { return pgx.ErrNoRows }} - } + ms := &mockAuthStore{ + getUserByEmailFn: func(_ context.Context, email string) (*model.User, error) { + return &model.User{ + ID: "user-id-1", + Email: email, + PasswordHash: hash, + DisplayName: "Test User", + Role: "maintainer", + }, nil }, - execFn: func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("INSERT 1"), nil + getPrimaryTeamIDFn: func(_ context.Context, _ string) (string, error) { + return teamID, nil + }, + createSessionFn: func(_ context.Context, _ string, _ string, _ string, _ net.IP, _ time.Time) error { + return nil }, } jwtMgr := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwtMgr, DB: mockDB} + h := &AuthHandler{JWT: jwtMgr, AuthStore: ms} body := `{"email":"test@example.com","password":"password123"}` req := httptest.NewRequest("POST", "/auth/login", strings.NewReader(body)) @@ -565,38 +566,26 @@ func TestLoginNoTeamHasEmptyTeamIDInJWT(t *testing.T) { t.Fatalf("HashPassword: %v", err) } - callCount := 0 - - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - callCount++ - switch callCount { - case 1: - // User lookup - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = "user-id-1" - *(dest[1].(*string)) = hash - *(dest[2].(*string)) = "Test User" - *(dest[3].(*string)) = "maintainer" - return nil - }} - case 2: - // No teams found - return &mockRow{scanFn: func(dest ...any) error { - return pgx.ErrNoRows - }} - default: - t.Errorf("unexpected QueryRow call #%d", callCount) - return &mockRow{scanFn: func(dest ...any) error { return pgx.ErrNoRows }} - } + ms := &mockAuthStore{ + getUserByEmailFn: func(_ context.Context, email string) (*model.User, error) { + return &model.User{ + ID: "user-id-1", + Email: email, + PasswordHash: hash, + DisplayName: "Test User", + Role: "maintainer", + }, nil + }, + getPrimaryTeamIDFn: func(_ context.Context, _ string) (string, error) { + return "", pgx.ErrNoRows }, - execFn: func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("INSERT 1"), nil + createSessionFn: func(_ context.Context, _ string, _ string, _ string, _ net.IP, _ time.Time) error { + return nil }, } jwtMgr := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwtMgr, DB: mockDB} + h := &AuthHandler{JWT: jwtMgr, AuthStore: ms} body := `{"email":"test@example.com","password":"password123"}` req := httptest.NewRequest("POST", "/auth/login", strings.NewReader(body)) @@ -629,19 +618,19 @@ func TestLoginNoTeamHasEmptyTeamIDInJWT(t *testing.T) { } func TestGetMeSuccess(t *testing.T) { - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = "test@test.com" - *(dest[1].(*string)) = "Test User" - *(dest[2].(*string)) = "maintainer" - return nil - }} + ms := &mockAuthStore{ + getUserByIDFn: func(_ context.Context, id string) (*model.User, error) { + return &model.User{ + ID: id, + Email: "test@test.com", + DisplayName: "Test User", + Role: "maintainer", + }, nil }, } jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwt, DB: mockDB} + h := &AuthHandler{JWT: jwt, AuthStore: ms} req := httptest.NewRequest("GET", "/api/v1/auth/me", nil) req = req.WithContext(auth.SetClaims(req.Context(), &auth.Claims{UserID: "user-123"})) @@ -669,16 +658,14 @@ func TestGetMeSuccess(t *testing.T) { } func TestGetMeUserNotFound(t *testing.T) { - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - return &mockRow{scanFn: func(dest ...any) error { - return pgx.ErrNoRows - }} + ms := &mockAuthStore{ + getUserByIDFn: func(_ context.Context, _ string) (*model.User, error) { + return nil, pgx.ErrNoRows }, } jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwt, DB: mockDB} + h := &AuthHandler{JWT: jwt, AuthStore: ms} req := httptest.NewRequest("GET", "/api/v1/auth/me", nil) req = req.WithContext(auth.SetClaims(req.Context(), &auth.Claims{UserID: "user-123"})) @@ -691,45 +678,29 @@ func TestGetMeUserNotFound(t *testing.T) { } } -// TestRegister_WhenOwnerConstraintViolated_RetriesAsMaintainer verifies that when -// the idx_users_single_owner unique partial index raises a 23505 violation (two -// concurrent registrations both evaluating as the first user), the handler retries -// with role='maintainer' and succeeds. func TestRegister_WhenOwnerConstraintViolated_RetriesAsMaintainer(t *testing.T) { callCount := 0 - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { + ms := &mockAuthStore{ + emailExistsFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + createUserFn: func(_ context.Context, _, _, _, _ string) (string, string, error) { callCount++ - switch callCount { - case 1: - // Email existence check - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*bool)) = false - return nil - }} - case 2: - // First INSERT attempt — simulate owner constraint violation - return &mockRow{scanFn: func(dest ...any) error { - return &pgconn.PgError{Code: "23505", ConstraintName: "idx_users_single_owner"} - }} - case 3: - // Retry INSERT as maintainer — succeeds - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = "user-uuid-1" - *(dest[1].(*string)) = "maintainer" - return nil - }} - default: - return &mockRow{scanFn: func(dest ...any) error { return pgx.ErrNoRows }} + if callCount == 1 { + return "", "", &pgconn.PgError{Code: "23505", ConstraintName: "idx_users_single_owner"} } + return "", "owner", nil }, - execFn: func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("INSERT 1"), nil + createUserWithRoleFn: func(_ context.Context, _, _, _, _ string) (string, error) { + return "user-uuid-1", nil + }, + createSessionFn: func(_ context.Context, _ string, _ string, _ string, _ net.IP, _ time.Time) error { + return nil }, } jwtMgr := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwtMgr, DB: mockDB} + h := &AuthHandler{JWT: jwtMgr, AuthStore: ms} body := `{"email":"admin@example.com","password":"password123","display_name":"Admin"}` req := httptest.NewRequest("POST", "/auth/register", strings.NewReader(body)) @@ -753,13 +724,8 @@ func TestRegister_WhenOwnerConstraintViolated_RetriesAsMaintainer(t *testing.T) if role := user["role"]; role != "maintainer" { t.Errorf("role = %q, want %q", role, "maintainer") } - if callCount != 3 { - t.Errorf("expected 3 QueryRow calls (email check + first INSERT + retry), got %d", callCount) - } } -// TestRegister_RoleAssignment verifies that the first registered user receives -// the 'owner' role and all subsequent users receive 'maintainer'. func TestRegister_RoleAssignment(t *testing.T) { tests := []struct { name string @@ -771,36 +737,20 @@ func TestRegister_RoleAssignment(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - callCount := 0 - mockDB := &mockAuthDB{ - queryRowFn: func(ctx context.Context, sql string, args ...any) pgx.Row { - callCount++ - switch callCount { - case 1: - // Email existence check → not taken - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*bool)) = false - return nil - }} - case 2: - // INSERT ... RETURNING id, role - return &mockRow{scanFn: func(dest ...any) error { - *(dest[0].(*string)) = "user-uuid" - *(dest[1].(*string)) = tc.role - return nil - }} - default: - t.Errorf("unexpected QueryRow call #%d", callCount) - return &mockRow{scanFn: func(dest ...any) error { return pgx.ErrNoRows }} - } + ms := &mockAuthStore{ + emailExistsFn: func(_ context.Context, _ string) (bool, error) { + return false, nil }, - execFn: func(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { - return pgconn.NewCommandTag("INSERT 1"), nil + createUserFn: func(_ context.Context, _, _, _, _ string) (string, string, error) { + return "user-uuid", tc.role, nil + }, + createSessionFn: func(_ context.Context, _ string, _ string, _ string, _ net.IP, _ time.Time) error { + return nil }, } jwtMgr := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - h := &AuthHandler{JWT: jwtMgr, DB: mockDB} + h := &AuthHandler{JWT: jwtMgr, AuthStore: ms} body := `{"email":"user@example.com","password":"password123","display_name":"User"}` req := httptest.NewRequest("POST", "/auth/register", strings.NewReader(body)) @@ -828,3 +778,208 @@ func TestRegister_RoleAssignment(t *testing.T) { }) } } + +// Store-aware handler tests + +func TestAuthHandler_Login_WithStore_Success(t *testing.T) { + hash, err := auth.HashPassword("password123") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + + callCount := 0 + ms := &mockAuthStore{ + getUserByEmailFn: func(_ context.Context, email string) (*model.User, error) { + callCount++ + return &model.User{ + ID: "uid-1", + Email: email, + PasswordHash: hash, + DisplayName: "Test", + Role: "maintainer", + }, nil + }, + getPrimaryTeamIDFn: func(_ context.Context, _ string) (string, error) { + return "team-1", nil + }, + createSessionFn: func(_ context.Context, _ string, _ string, _ string, _ net.IP, _ time.Time) error { + return nil + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + body := `{"email":"test@test.com","password":"password123"}` + req := httptest.NewRequest("POST", "/auth/login", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.Login(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Login with store: status = %d, want %d (body: %s)", w.Code, http.StatusOK, w.Body.String()) + } + if callCount != 1 { + t.Errorf("expected 1 GetUserByEmail call, got %d", callCount) + } +} + +func TestAuthHandler_Login_WithStore_InvalidCredentials(t *testing.T) { + hash, err := auth.HashPassword("correctpassword") + if err != nil { + t.Fatalf("HashPassword: %v", err) + } + + ms := &mockAuthStore{ + getUserByEmailFn: func(_ context.Context, _ string) (*model.User, error) { + return &model.User{ + ID: "uid-1", + Email: "test@test.com", + PasswordHash: hash, + DisplayName: "Test", + Role: "maintainer", + }, nil + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + body := `{"email":"test@test.com","password":"wrongpassword"}` + req := httptest.NewRequest("POST", "/auth/login", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.Login(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Login wrong password: status = %d, want %d", w.Code, http.StatusUnauthorized) + } +} + +func TestAuthHandler_Login_WithStore_UserNotFound(t *testing.T) { + ms := &mockAuthStore{ + getUserByEmailFn: func(_ context.Context, _ string) (*model.User, error) { + return nil, pgx.ErrNoRows + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + body := `{"email":"nobody@test.com","password":"password123"}` + req := httptest.NewRequest("POST", "/auth/login", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.Login(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Login user not found: status = %d, want %d", w.Code, http.StatusUnauthorized) + } +} + +func TestAuthHandler_Refresh_WithStore_Expired(t *testing.T) { + ms := &mockAuthStore{ + getSessionByRefreshTokenFn: func(_ context.Context, _ string) (*store.SessionInfo, error) { + return &store.SessionInfo{ + ID: "sess-1", + UserID: "uid-1", + ExpiresAt: time.Now().Add(-1 * time.Hour), + }, nil + }, + deleteSessionFn: func(_ context.Context, _ string) error { + return nil + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + req := httptest.NewRequest("POST", "/auth/refresh", nil) + req.AddCookie(&http.Cookie{Name: "refresh_token", Value: "expired-token"}) + w := httptest.NewRecorder() + + h.Refresh(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Refresh expired: status = %d, want %d", w.Code, http.StatusUnauthorized) + } +} + +func TestAuthHandler_Logout_WithStore(t *testing.T) { + deleteCalled := false + ms := &mockAuthStore{ + deleteSessionByRefreshFn: func(_ context.Context, _ string) error { + deleteCalled = true + return nil + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + req := httptest.NewRequest("POST", "/auth/logout", nil) + req.AddCookie(&http.Cookie{Name: "refresh_token", Value: "some-token"}) + w := httptest.NewRecorder() + + h.Logout(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Logout: status = %d, want %d", w.Code, http.StatusOK) + } + if !deleteCalled { + t.Error("expected DeleteSessionByRefreshToken to be called") + } +} + +func TestAuthHandler_UpdateMe_WithStore(t *testing.T) { + ms := &mockAuthStore{ + updateProfileFn: func(_ context.Context, userID, displayName string) (*model.User, error) { + return &model.User{ + ID: userID, + Email: "test@test.com", + DisplayName: displayName, + Role: "maintainer", + }, nil + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + body := `{"display_name":"New Name"}` + req := httptest.NewRequest("PATCH", "/api/v1/auth/me", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(auth.SetClaims(req.Context(), &auth.Claims{UserID: "user-123"})) + w := httptest.NewRecorder() + + h.UpdateMe(w, req) + + if w.Code != http.StatusOK { + t.Errorf("UpdateMe: status = %d, want %d (body: %s)", w.Code, http.StatusOK, w.Body.String()) + } +} + +func TestAuthHandler_Register_WithStore_EmailExists(t *testing.T) { + ms := &mockAuthStore{ + emailExistsFn: func(_ context.Context, _ string) (bool, error) { + return true, nil + }, + } + + jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) + h := &AuthHandler{JWT: jwt, AuthStore: ms} + + body := `{"email":"taken@test.com","password":"password123","display_name":"Test"}` + req := httptest.NewRequest("POST", "/auth/register", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.Register(w, req) + + if w.Code != http.StatusConflict { + t.Errorf("Register email taken: status = %d, want %d", w.Code, http.StatusConflict) + } +} diff --git a/internal/handler/executions.go b/internal/handler/executions.go index b0048cc8..25d18efe 100644 --- a/internal/handler/executions.go +++ b/internal/handler/executions.go @@ -4,19 +4,15 @@ import ( "context" "encoding/json" "net/http" - "strconv" "time" "github.com/go-chi/chi/v5" - "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/rs/zerolog/log" "github.com/scaledtest/scaledtest/internal/auth" - "github.com/scaledtest/scaledtest/internal/db" "github.com/scaledtest/scaledtest/internal/k8s" - "github.com/scaledtest/scaledtest/internal/model" "github.com/scaledtest/scaledtest/internal/sanitize" "github.com/scaledtest/scaledtest/internal/store" "github.com/scaledtest/scaledtest/internal/webhook" @@ -25,18 +21,14 @@ import ( // ExecutionsHandler handles test execution endpoints. type ExecutionsHandler struct { - DB *db.Pool - Hub *ws.Hub // WebSocket hub for real-time broadcasting (optional) - AuditStore *store.AuditStore // optional; nil means no audit logging - K8s *k8s.Client // optional; nil means K8s job launch is disabled - WorkerImage string // default container image for test workers - WorkerToken string // auth token workers use to report back - APIBaseURL string // base URL workers use to call the API - Webhooks *webhook.Notifier // optional; nil means no webhook dispatch - - // ownsExecFunc overrides ownsExecution for testing. If nil, the DB-based - // implementation is used. Callers must not set this outside of tests. - ownsExecFunc func(ctx context.Context, executionID, teamID string) (bool, error) + ExecStore executionsStore + Hub *ws.Hub + AuditStore *store.AuditStore + K8s *k8s.Client + WorkerImage string + WorkerToken string + APIBaseURL string + Webhooks *webhook.Notifier } // CreateExecutionRequest is the request body for creating a test execution. @@ -54,55 +46,18 @@ func (h *ExecutionsHandler) List(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ExecStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - // Pagination limit, offset := parsePagination(r) - rows, err := h.DB.Query(r.Context(), - `SELECT id, team_id, status, command, config, report_id, k8s_job_name, k8s_pod_name, - error_msg, started_at, finished_at, created_at, updated_at - FROM test_executions - WHERE team_id = $1 - ORDER BY created_at DESC - LIMIT $2 OFFSET $3`, - claims.TeamID, limit, offset) + executions, total, err := h.ExecStore.List(r.Context(), claims.TeamID, limit, offset) if err != nil { Error(w, http.StatusInternalServerError, "failed to query executions") return } - defer rows.Close() - - executions := []model.TestExecution{} - for rows.Next() { - var e model.TestExecution - if err := rows.Scan( - &e.ID, &e.TeamID, &e.Status, &e.Command, &e.Config, &e.ReportID, - &e.K8sJobName, &e.K8sPodName, &e.ErrorMsg, &e.StartedAt, - &e.FinishedAt, &e.CreatedAt, &e.UpdatedAt, - ); err != nil { - Error(w, http.StatusInternalServerError, "failed to scan execution") - return - } - executions = append(executions, e) - } - if err := rows.Err(); err != nil { - Error(w, http.StatusInternalServerError, "failed to iterate executions") - return - } - - // Get total count for pagination - var total int - err = h.DB.QueryRow(r.Context(), - `SELECT COUNT(*) FROM test_executions WHERE team_id = $1`, - claims.TeamID).Scan(&total) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to count executions") - return - } JSON(w, http.StatusOK, map[string]interface{}{ "executions": executions, @@ -124,17 +79,15 @@ func (h *ExecutionsHandler) Create(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ExecStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - // Sanitize user-provided strings req.Command = sanitize.String(req.Command) req.Image = sanitize.String(req.Image) req.EnvVars = sanitize.StringMap(req.EnvVars) - // Build config JSON from image and env vars var configJSON []byte if req.Image != "" || len(req.EnvVars) > 0 { cfg := map[string]interface{}{} @@ -152,19 +105,12 @@ func (h *ExecutionsHandler) Create(w http.ResponseWriter, r *http.Request) { } } - id := uuid.New().String() - now := time.Now() - - _, err := h.DB.Exec(r.Context(), - `INSERT INTO test_executions (id, team_id, status, command, config, created_at, updated_at) - VALUES ($1, $2, 'pending', $3, $4, $5, $5)`, - id, claims.TeamID, req.Command, configJSON, now) + id, err := h.ExecStore.Create(r.Context(), claims.TeamID, req.Command, configJSON) if err != nil { Error(w, http.StatusInternalServerError, "failed to create execution") return } - // Launch K8s job if client is configured if h.K8s != nil { image := req.Image if image == "" { @@ -182,16 +128,13 @@ func (h *ExecutionsHandler) Create(w http.ResponseWriter, r *http.Request) { } if _, err := h.K8s.CreateJob(r.Context(), jobCfg); err != nil { log.Error().Err(err).Str("execution_id", id).Msg("failed to launch k8s job") - h.DB.Exec(r.Context(), - `UPDATE test_executions SET status = 'failed', error_msg = $1, updated_at = $2 WHERE id = $3`, - "job launch failed: "+err.Error(), time.Now(), id) + _ = h.ExecStore.MarkFailed(r.Context(), id, "job launch failed: "+err.Error(), time.Now()) Error(w, http.StatusInternalServerError, "execution created but job launch failed") return } - // Store K8s job name on the execution record - h.DB.Exec(r.Context(), - `UPDATE test_executions SET k8s_job_name = $1, updated_at = $2 WHERE id = $3`, - jobName, time.Now(), id) + if err := h.ExecStore.SetK8sJobName(r.Context(), id, jobName, time.Now()); err != nil { + log.Error().Err(err).Str("execution_id", id).Msg("failed to store k8s job name") + } } if h.AuditStore != nil { @@ -227,12 +170,12 @@ func (h *ExecutionsHandler) Get(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ExecStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - e, err := getExecution(r.Context(), h.DB, executionID, claims.TeamID) + e, err := h.ExecStore.Get(r.Context(), executionID, claims.TeamID) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "execution not found") return @@ -259,31 +202,24 @@ func (h *ExecutionsHandler) Cancel(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ExecStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } now := time.Now() - tag, err := h.DB.Exec(r.Context(), - `UPDATE test_executions - SET status = 'cancelled', finished_at = $1, updated_at = $1 - WHERE id = $2 AND team_id = $3 AND status IN ('pending', 'running')`, - now, executionID, claims.TeamID) + rowsAffected, err := h.ExecStore.Cancel(r.Context(), executionID, claims.TeamID, now) if err != nil { Error(w, http.StatusInternalServerError, "failed to cancel execution") return } - if tag.RowsAffected() == 0 { + if rowsAffected == 0 { Error(w, http.StatusNotFound, "execution not found or not cancellable") return } - // Clean up the K8s job if one was launched if h.K8s != nil { - var jobName *string - _ = h.DB.QueryRow(r.Context(), - `SELECT k8s_job_name FROM test_executions WHERE id = $1`, executionID).Scan(&jobName) + jobName, _ := h.ExecStore.GetK8sJobName(r.Context(), executionID) if jobName != nil && *jobName != "" { if err := h.K8s.DeleteJob(r.Context(), *jobName); err != nil { log.Error().Err(err).Str("job", *jobName).Msg("failed to delete k8s job on cancel") @@ -335,60 +271,34 @@ func (h *ExecutionsHandler) UpdateStatus(w http.ResponseWriter, r *http.Request) return } - if h.DB == nil { + if h.ExecStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - // Sanitize user-provided strings req.ErrorMsg = sanitize.String(req.ErrorMsg) now := time.Now() - - // Build dynamic update - query := `UPDATE test_executions SET status = $1, updated_at = $2` - args := []interface{}{req.Status, now} - argIdx := 3 - - if req.Status == "running" { - query += `, started_at = COALESCE(started_at, $` + strconv.Itoa(argIdx) + `)` - args = append(args, now) - argIdx++ - } - - if req.Status == "completed" || req.Status == "failed" || req.Status == "cancelled" { - query += `, finished_at = $` + strconv.Itoa(argIdx) - args = append(args, now) - argIdx++ - } - + var errorMsg *string if req.ErrorMsg != "" { - query += `, error_msg = $` + strconv.Itoa(argIdx) - args = append(args, req.ErrorMsg) - argIdx++ + errorMsg = &req.ErrorMsg } - - query += ` WHERE id = $` + strconv.Itoa(argIdx) + ` AND team_id = $` + strconv.Itoa(argIdx+1) - args = append(args, executionID, claims.TeamID) - - tag, err := h.DB.Exec(r.Context(), query, args...) + rowsAffected, err := h.ExecStore.UpdateStatus(r.Context(), executionID, claims.TeamID, req.Status, now, errorMsg) if err != nil { Error(w, http.StatusInternalServerError, "failed to update execution status") return } - if tag.RowsAffected() == 0 { + if rowsAffected == 0 { Error(w, http.StatusNotFound, "execution not found") return } - // Broadcast status change via WebSocket if h.Hub != nil { h.Hub.BroadcastExecutionStatus(executionID, req.Status, map[string]interface{}{ "error_msg": req.ErrorMsg, }) } - // Audit terminal state transitions (completed/failed). if h.AuditStore != nil && (req.Status == "completed" || req.Status == "failed") { meta := map[string]interface{}{"status": req.Status} if req.ErrorMsg != "" { @@ -409,7 +319,6 @@ func (h *ExecutionsHandler) UpdateStatus(w http.ResponseWriter, r *http.Request) }) } - // Fire webhooks for terminal execution states. if req.Status == "completed" || req.Status == "failed" { eventType := webhook.EventExecutionCompleted if req.Status == "failed" { @@ -431,41 +340,9 @@ func (h *ExecutionsHandler) UpdateStatus(w http.ResponseWriter, r *http.Request) }) } -// getExecution fetches a single execution by ID, scoped to team. -func getExecution(ctx context.Context, pool *db.Pool, id, teamID string) (*model.TestExecution, error) { - var e model.TestExecution - err := pool.QueryRow(ctx, - `SELECT id, team_id, status, command, config, report_id, k8s_job_name, k8s_pod_name, - error_msg, started_at, finished_at, created_at, updated_at - FROM test_executions - WHERE id = $1 AND team_id = $2`, - id, teamID).Scan( - &e.ID, &e.TeamID, &e.Status, &e.Command, &e.Config, &e.ReportID, - &e.K8sJobName, &e.K8sPodName, &e.ErrorMsg, &e.StartedAt, - &e.FinishedAt, &e.CreatedAt, &e.UpdatedAt, - ) - if err != nil { - return nil, err - } - return &e, nil -} - // ownsExecution checks whether the given execution belongs to the specified team. -// Returns (false, nil) if the execution does not belong to the team, -// (false, err) on database errors (distinguishing 404 from 500), -// and (true, nil) if the execution belongs to the team. func (h *ExecutionsHandler) ownsExecution(ctx context.Context, executionID, teamID string) (bool, error) { - if h.ownsExecFunc != nil { - return h.ownsExecFunc(ctx, executionID, teamID) - } - var exists bool - err := h.DB.QueryRow(ctx, - `SELECT EXISTS(SELECT 1 FROM test_executions WHERE id = $1 AND team_id = $2)`, - executionID, teamID).Scan(&exists) - if err != nil { - return false, err - } - return exists, nil + return h.ExecStore.Exists(ctx, executionID, teamID) } func (h *ExecutionsHandler) requireWorkerCallback(w http.ResponseWriter, r *http.Request) (*auth.Claims, string, bool) { @@ -481,7 +358,7 @@ func (h *ExecutionsHandler) requireWorkerCallback(w http.ResponseWriter, r *http return nil, "", false } - if h.DB == nil { + if h.ExecStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return nil, "", false } @@ -523,7 +400,6 @@ func (h *ExecutionsHandler) ReportProgress(w http.ResponseWriter, r *http.Reques return } - // Broadcast progress via WebSocket if h.Hub != nil { h.Hub.BroadcastProgress(executionID, map[string]interface{}{ "passed": req.Passed, @@ -566,7 +442,6 @@ func (h *ExecutionsHandler) ReportTestResult(w http.ResponseWriter, r *http.Requ return } - // Broadcast individual test result via WebSocket if h.Hub != nil { h.Hub.BroadcastTestResult(executionID, map[string]interface{}{ "name": req.Name, @@ -607,7 +482,6 @@ func (h *ExecutionsHandler) ReportWorkerStatus(w http.ResponseWriter, r *http.Re return } - // Broadcast worker status via WebSocket if h.Hub != nil { h.Hub.BroadcastWorkerStatus(executionID, map[string]interface{}{ "worker_id": req.WorkerID, diff --git a/internal/handler/executions_test.go b/internal/handler/executions_test.go index d97d6c77..028437f8 100644 --- a/internal/handler/executions_test.go +++ b/internal/handler/executions_test.go @@ -8,10 +8,53 @@ import ( "net/http/httptest" "strings" "testing" + "time" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5" + + "github.com/scaledtest/scaledtest/internal/model" ) +type mockExecutionsStore struct { + listFn func(ctx context.Context, teamID string, limit, offset int) ([]model.TestExecution, int, error) + createFn func(ctx context.Context, teamID, command string, configJSON []byte) (string, error) + getFn func(ctx context.Context, id, teamID string) (*model.TestExecution, error) + cancelFn func(ctx context.Context, id, teamID string, now time.Time) (int64, error) + updateStatusFn func(ctx context.Context, id, teamID, status string, now time.Time, errorMsg *string) (int64, error) + existsFn func(ctx context.Context, id, teamID string) (bool, error) + getK8sJobNameFn func(ctx context.Context, id string) (*string, error) + setK8sJobNameFn func(ctx context.Context, id, jobName string, now time.Time) error + markFailedFn func(ctx context.Context, id, errorMsg string, now time.Time) error +} + +func (m *mockExecutionsStore) List(ctx context.Context, teamID string, limit, offset int) ([]model.TestExecution, int, error) { + return m.listFn(ctx, teamID, limit, offset) +} +func (m *mockExecutionsStore) Create(ctx context.Context, teamID, command string, configJSON []byte) (string, error) { + return m.createFn(ctx, teamID, command, configJSON) +} +func (m *mockExecutionsStore) Get(ctx context.Context, id, teamID string) (*model.TestExecution, error) { + return m.getFn(ctx, id, teamID) +} +func (m *mockExecutionsStore) Cancel(ctx context.Context, id, teamID string, now time.Time) (int64, error) { + return m.cancelFn(ctx, id, teamID, now) +} +func (m *mockExecutionsStore) UpdateStatus(ctx context.Context, id, teamID, status string, now time.Time, errorMsg *string) (int64, error) { + return m.updateStatusFn(ctx, id, teamID, status, now, errorMsg) +} +func (m *mockExecutionsStore) Exists(ctx context.Context, id, teamID string) (bool, error) { + return m.existsFn(ctx, id, teamID) +} +func (m *mockExecutionsStore) GetK8sJobName(ctx context.Context, id string) (*string, error) { + return m.getK8sJobNameFn(ctx, id) +} +func (m *mockExecutionsStore) SetK8sJobName(ctx context.Context, id, jobName string, now time.Time) error { + return m.setK8sJobNameFn(ctx, id, jobName, now) +} +func (m *mockExecutionsStore) MarkFailed(ctx context.Context, id, errorMsg string, now time.Time) error { + return m.markFailedFn(ctx, id, errorMsg, now) +} + func TestListExecutions_Unauthorized(t *testing.T) { h := &ExecutionsHandler{} w := httptest.NewRecorder() @@ -25,7 +68,7 @@ func TestListExecutions_Unauthorized(t *testing.T) { } func TestListExecutions_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/executions", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -50,7 +93,7 @@ func TestCreateExecution_Unauthorized(t *testing.T) { } func TestCreateExecution_InvalidBody(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions", strings.NewReader(`{invalid}`)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -63,7 +106,7 @@ func TestCreateExecution_InvalidBody(t *testing.T) { } func TestCreateExecution_MissingCommand(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions", strings.NewReader(`{}`)) r.Header.Set("Content-Type", "application/json") @@ -77,7 +120,7 @@ func TestCreateExecution_MissingCommand(t *testing.T) { } func TestCreateExecution_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions", strings.NewReader(`{"command":"npm test"}`)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -117,7 +160,7 @@ func TestGetExecution_MissingID(t *testing.T) { } func TestGetExecution_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/executions/abc", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -158,7 +201,7 @@ func TestCancelExecution_MissingID(t *testing.T) { } func TestCancelExecution_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("DELETE", "/api/v1/executions/abc", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -185,7 +228,7 @@ func TestUpdateStatus_Unauthorized(t *testing.T) { } func TestUpdateStatus_InvalidBody(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("PUT", "/api/v1/executions/abc/status", strings.NewReader(`{invalid}`)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -199,7 +242,7 @@ func TestUpdateStatus_InvalidBody(t *testing.T) { } func TestUpdateStatus_InvalidStatus(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("PUT", "/api/v1/executions/abc/status", strings.NewReader(`{"status":"invalid"}`)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -213,7 +256,7 @@ func TestUpdateStatus_InvalidStatus(t *testing.T) { } func TestUpdateStatus_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("PUT", "/api/v1/executions/abc/status", strings.NewReader(`{"status":"running"}`)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -241,7 +284,7 @@ func TestReportProgress_Unauthorized(t *testing.T) { } func TestReportProgress_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/abc/progress", strings.NewReader(`{"total":1,"passed":1}`)) r.Header.Set("Content-Type", "application/json") @@ -270,7 +313,7 @@ func TestReportTestResult_Unauthorized(t *testing.T) { } func TestReportTestResult_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/abc/test-result", strings.NewReader(`{"name":"test","status":"passed"}`)) r.Header.Set("Content-Type", "application/json") @@ -299,7 +342,7 @@ func TestReportWorkerStatus_Unauthorized(t *testing.T) { } func TestReportWorkerStatus_NoDB(t *testing.T) { - h := &ExecutionsHandler{DB: nil} + h := &ExecutionsHandler{ExecStore: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/abc/worker-status", strings.NewReader(`{"worker_id":"w1","status":"running"}`)) r.Header.Set("Content-Type", "application/json") @@ -314,12 +357,12 @@ func TestReportWorkerStatus_NoDB(t *testing.T) { } func TestReportProgress_CrossTeam_Forbidden(t *testing.T) { - h := &ExecutionsHandler{ - DB: new(pgxpool.Pool), - ownsExecFunc: func(_ context.Context, executionID, teamID string) (bool, error) { + ms := &mockExecutionsStore{ + existsFn: func(_ context.Context, _, _ string) (bool, error) { return false, nil }, } + h := &ExecutionsHandler{ExecStore: ms} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/exec-1/progress", strings.NewReader(`{"total":1,"passed":1}`)) r.Header.Set("Content-Type", "application/json") @@ -334,12 +377,12 @@ func TestReportProgress_CrossTeam_Forbidden(t *testing.T) { } func TestReportTestResult_CrossTeam_Forbidden(t *testing.T) { - h := &ExecutionsHandler{ - DB: new(pgxpool.Pool), - ownsExecFunc: func(_ context.Context, executionID, teamID string) (bool, error) { + ms := &mockExecutionsStore{ + existsFn: func(_ context.Context, _, _ string) (bool, error) { return false, nil }, } + h := &ExecutionsHandler{ExecStore: ms} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/exec-1/test-result", strings.NewReader(`{"name":"test","status":"passed"}`)) r.Header.Set("Content-Type", "application/json") @@ -354,12 +397,12 @@ func TestReportTestResult_CrossTeam_Forbidden(t *testing.T) { } func TestReportWorkerStatus_CrossTeam_Forbidden(t *testing.T) { - h := &ExecutionsHandler{ - DB: new(pgxpool.Pool), - ownsExecFunc: func(_ context.Context, executionID, teamID string) (bool, error) { + ms := &mockExecutionsStore{ + existsFn: func(_ context.Context, _, _ string) (bool, error) { return false, nil }, } + h := &ExecutionsHandler{ExecStore: ms} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/exec-1/worker-status", strings.NewReader(`{"worker_id":"w1","status":"running"}`)) r.Header.Set("Content-Type", "application/json") @@ -374,12 +417,12 @@ func TestReportWorkerStatus_CrossTeam_Forbidden(t *testing.T) { } func TestReportProgress_DBError_Returns500(t *testing.T) { - h := &ExecutionsHandler{ - DB: new(pgxpool.Pool), - ownsExecFunc: func(_ context.Context, executionID, teamID string) (bool, error) { + ms := &mockExecutionsStore{ + existsFn: func(_ context.Context, _, _ string) (bool, error) { return false, fmt.Errorf("connection refused") }, } + h := &ExecutionsHandler{ExecStore: ms} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/executions/exec-1/progress", strings.NewReader(`{"total":1,"passed":1}`)) r.Header.Set("Content-Type", "application/json") @@ -409,3 +452,176 @@ func TestErrorResponse_Format(t *testing.T) { t.Errorf("Error message: got %q, want %q", resp["error"], "test error") } } + +// Store-aware handler tests + +func TestExecutionsHandler_List_WithStore(t *testing.T) { + ms := &mockExecutionsStore{ + listFn: func(_ context.Context, _ string, _, _ int) ([]model.TestExecution, int, error) { + return []model.TestExecution{{ID: "exec-1", Status: "pending"}}, 1, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/executions", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.List(w, r) + + if w.Code != http.StatusOK { + t.Errorf("List with store: status = %d, want %d", w.Code, http.StatusOK) + } +} + +func TestExecutionsHandler_Get_WithStore_Found(t *testing.T) { + ms := &mockExecutionsStore{ + getFn: func(_ context.Context, id, _ string) (*model.TestExecution, error) { + return &model.TestExecution{ID: id, Status: "running"}, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/executions/exec-1", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "exec-1") + + h.Get(w, r) + + if w.Code != http.StatusOK { + t.Errorf("Get with store: status = %d, want %d", w.Code, http.StatusOK) + } +} + +func TestExecutionsHandler_Get_WithStore_NotFound(t *testing.T) { + ms := &mockExecutionsStore{ + getFn: func(_ context.Context, _, _ string) (*model.TestExecution, error) { + return nil, pgx.ErrNoRows + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/executions/nonexistent", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "nonexistent") + + h.Get(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Get not found: status = %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestExecutionsHandler_Cancel_WithStore(t *testing.T) { + ms := &mockExecutionsStore{ + cancelFn: func(_ context.Context, _, _ string, _ time.Time) (int64, error) { + return 1, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/api/v1/executions/exec-1", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "exec-1") + + h.Cancel(w, r) + + if w.Code != http.StatusOK { + t.Errorf("Cancel with store: status = %d, want %d", w.Code, http.StatusOK) + } +} + +func TestExecutionsHandler_Cancel_WithStore_NotFound(t *testing.T) { + ms := &mockExecutionsStore{ + cancelFn: func(_ context.Context, _, _ string, _ time.Time) (int64, error) { + return 0, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/api/v1/executions/nonexistent", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "nonexistent") + + h.Cancel(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Cancel not found: status = %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestExecutionsHandler_Create_WithStore(t *testing.T) { + ms := &mockExecutionsStore{ + createFn: func(_ context.Context, _, _ string, _ []byte) (string, error) { + return "exec-new", nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/api/v1/executions", strings.NewReader(`{"command":"npm test"}`)) + r.Header.Set("Content-Type", "application/json") + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.Create(w, r) + + if w.Code != http.StatusCreated { + t.Errorf("Create with store: status = %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + } +} + +func TestExecutionsHandler_UpdateStatus_WithStore(t *testing.T) { + ms := &mockExecutionsStore{ + updateStatusFn: func(_ context.Context, _, _, _ string, _ time.Time, _ *string) (int64, error) { + return 1, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("PUT", "/api/v1/executions/exec-1/status", strings.NewReader(`{"status":"running"}`)) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "exec-1") + + h.UpdateStatus(w, r) + + if w.Code != http.StatusOK { + t.Errorf("UpdateStatus with store: status = %d, want %d (body: %s)", w.Code, http.StatusOK, w.Body.String()) + } +} + +func TestExecutionsHandler_UpdateStatus_WithStore_NotFound(t *testing.T) { + ms := &mockExecutionsStore{ + updateStatusFn: func(_ context.Context, _, _, _ string, _ time.Time, _ *string) (int64, error) { + return 0, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("PUT", "/api/v1/executions/nonexistent/status", strings.NewReader(`{"status":"running"}`)) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "nonexistent") + + h.UpdateStatus(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("UpdateStatus not found: status = %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestExecutionsHandler_ReportProgress_Owned(t *testing.T) { + ms := &mockExecutionsStore{ + existsFn: func(_ context.Context, _, _ string) (bool, error) { + return true, nil + }, + } + h := &ExecutionsHandler{ExecStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", "/api/v1/executions/exec-1/progress", strings.NewReader(`{"total":1,"passed":1}`)) + r.Header.Set("Content-Type", "application/json") + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "executionID", "exec-1") + + h.ReportProgress(w, r) + + if w.Code != http.StatusOK { + t.Errorf("ReportProgress owned: status = %d, want %d", w.Code, http.StatusOK) + } +} diff --git a/internal/handler/invitations.go b/internal/handler/invitations.go index b5c96877..de7f5b44 100644 --- a/internal/handler/invitations.go +++ b/internal/handler/invitations.go @@ -13,7 +13,6 @@ import ( "github.com/rs/zerolog/log" "github.com/scaledtest/scaledtest/internal/auth" - "github.com/scaledtest/scaledtest/internal/db" "github.com/scaledtest/scaledtest/internal/mailer" "github.com/scaledtest/scaledtest/internal/model" "github.com/scaledtest/scaledtest/internal/sanitize" @@ -36,12 +35,12 @@ type invitationStore interface { GetByTokenHash(ctx context.Context, tokenHash string) (*model.Invitation, error) Delete(ctx context.Context, teamID, id string) error AcceptInvitation(ctx context.Context, invID, email, passwordHash, displayName, role, teamID string) (string, error) + GetTeamName(ctx context.Context, teamID string) (string, error) } // InvitationsHandler handles invitation endpoints. type InvitationsHandler struct { Store invitationStore - DB *db.Pool Mailer mailer.Mailer BaseURL string AuditStore auditLogger @@ -88,7 +87,7 @@ func (h *InvitationsHandler) Create(w http.ResponseWriter, r *http.Request) { return } - if h.Store == nil || h.DB == nil { + if h.Store == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -182,7 +181,7 @@ func (h *InvitationsHandler) Preview(w http.ResponseWriter, r *http.Request) { return } - if h.Store == nil || h.DB == nil { + if h.Store == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -204,9 +203,7 @@ func (h *InvitationsHandler) Preview(w http.ResponseWriter, r *http.Request) { return } - // Look up team name for display - var teamName string - err = h.DB.QueryRow(r.Context(), "SELECT name FROM teams WHERE id = $1", inv.TeamID).Scan(&teamName) + teamName, err := h.Store.GetTeamName(r.Context(), inv.TeamID) if err != nil { teamName = "Unknown" } diff --git a/internal/handler/invitations_test.go b/internal/handler/invitations_test.go index ca68d04a..17b2d9cf 100644 --- a/internal/handler/invitations_test.go +++ b/internal/handler/invitations_test.go @@ -10,8 +10,6 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/scaledtest/scaledtest/internal/auth" "github.com/scaledtest/scaledtest/internal/model" "github.com/scaledtest/scaledtest/internal/store" @@ -19,12 +17,14 @@ import ( // mockInvitationStore is a test double for invitationStore. type mockInvitationStore struct { - inv *model.Invitation // returned by Create - err error // returned by Create - tokenInv *model.Invitation // returned by GetByTokenHash - tokenErr error // returned by GetByTokenHash - acceptedUserID string // returned by AcceptInvitation - acceptErr error // returned by AcceptInvitation + inv *model.Invitation + err error + tokenInv *model.Invitation + tokenErr error + acceptedUserID string + acceptErr error + teamName string + teamNameErr error } func (m *mockInvitationStore) Create(_ context.Context, _, _, _, _, _ string, _ time.Time) (*model.Invitation, error) { @@ -47,6 +47,10 @@ func (m *mockInvitationStore) AcceptInvitation(_ context.Context, _, _, _, _, _, return m.acceptedUserID, m.acceptErr } +func (m *mockInvitationStore) GetTeamName(_ context.Context, _ string) (string, error) { + return m.teamName, m.teamNameErr +} + // mockMailer is a test double for mailer.Mailer. type mockMailer struct { called bool @@ -117,7 +121,7 @@ func TestCreateInvitation_InvalidRole(t *testing.T) { } func TestCreateInvitation_NoDB(t *testing.T) { - h := &InvitationsHandler{Store: nil, DB: nil} + h := &InvitationsHandler{Store: nil} w := httptest.NewRecorder() body := `{"email":"test@example.com","role":"readonly"}` r := httptest.NewRequest("POST", "/api/v1/teams/t1/invitations", strings.NewReader(body)) @@ -127,7 +131,7 @@ func TestCreateInvitation_NoDB(t *testing.T) { h.Create(w, r) if w.Code != http.StatusServiceUnavailable { - t.Errorf("Create without DB: got %d, want %d", w.Code, http.StatusServiceUnavailable) + t.Errorf("Create without store: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -177,7 +181,7 @@ func TestPreviewInvitation_MissingToken(t *testing.T) { } func TestPreviewInvitation_NoDB(t *testing.T) { - h := &InvitationsHandler{Store: nil, DB: nil} + h := &InvitationsHandler{Store: nil} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/invitations/inv_abc", nil) r = testWithChiParam(r, "token", "inv_abc") @@ -185,7 +189,7 @@ func TestPreviewInvitation_NoDB(t *testing.T) { h.Preview(w, r) if w.Code != http.StatusServiceUnavailable { - t.Errorf("Preview without DB: got %d, want %d", w.Code, http.StatusServiceUnavailable) + t.Errorf("Preview without store: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -203,7 +207,7 @@ func TestAcceptInvitation_MissingToken(t *testing.T) { } func TestAcceptInvitation_NoDB(t *testing.T) { - h := &InvitationsHandler{Store: nil, DB: nil} + h := &InvitationsHandler{Store: nil} w := httptest.NewRecorder() body := `{"password":"password123","display_name":"Test User"}` r := httptest.NewRequest("POST", "/api/v1/invitations/inv_abc/accept", strings.NewReader(body)) @@ -213,13 +217,12 @@ func TestAcceptInvitation_NoDB(t *testing.T) { h.Accept(w, r) if w.Code != http.StatusServiceUnavailable { - t.Errorf("Accept without DB: got %d, want %d", w.Code, http.StatusServiceUnavailable) + t.Errorf("Accept without store: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } func TestAcceptInvitation_InvalidBody(t *testing.T) { - // With nil DB, handler returns 503 before parsing body - h := &InvitationsHandler{Store: nil, DB: nil} + h := &InvitationsHandler{Store: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/invitations/inv_abc/accept", strings.NewReader(`{invalid}`)) r.Header.Set("Content-Type", "application/json") @@ -228,7 +231,7 @@ func TestAcceptInvitation_InvalidBody(t *testing.T) { h.Accept(w, r) if w.Code != http.StatusServiceUnavailable { - t.Errorf("Accept with nil DB: got %d, want %d", w.Code, http.StatusServiceUnavailable) + t.Errorf("Accept with nil store: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -347,8 +350,8 @@ func TestCreateInvitation_CallsMailer(t *testing.T) { store := &mockInvitationStore{inv: inv} ml := &mockMailer{} h := &InvitationsHandler{ - Store: store, - DB: new(pgxpool.Pool), + Store: store, + Mailer: ml, BaseURL: "http://app.example.com", } @@ -386,8 +389,8 @@ func TestCreateInvitation_NilMailer_ReturnsCreated(t *testing.T) { CreatedAt: time.Now(), } h := &InvitationsHandler{ - Store: &mockInvitationStore{inv: inv}, - DB: new(pgxpool.Pool), + Store: &mockInvitationStore{inv: inv}, + Mailer: nil, // no SMTP configured BaseURL: "http://app.example.com", } @@ -417,8 +420,8 @@ func TestCreateInvitation_MailerError_StillReturnsCreated(t *testing.T) { } ml := &mockMailer{err: fmt.Errorf("smtp connection refused")} h := &InvitationsHandler{ - Store: &mockInvitationStore{inv: inv}, - DB: new(pgxpool.Pool), + Store: &mockInvitationStore{inv: inv}, + Mailer: ml, BaseURL: "http://app.example.com", } @@ -462,7 +465,6 @@ func TestCreateInvitation_LogsAuditEvent(t *testing.T) { al := &capAuditLogger{} h := &InvitationsHandler{ Store: ms, - DB: new(pgxpool.Pool), AuditStore: al, } @@ -507,7 +509,6 @@ func TestCreateInvitation_NilAuditStore_NoPanic(t *testing.T) { } h := &InvitationsHandler{ Store: &mockInvitationStore{inv: inv}, - DB: new(pgxpool.Pool), AuditStore: nil, } @@ -644,3 +645,69 @@ func TestAcceptInvitation_LogsAuditEvent(t *testing.T) { t.Errorf("audit actor_id = %q, want %q", e.ActorID, "new-user-1") } } + +func TestPreviewInvitation_WithStore_ReturnsTeamName(t *testing.T) { + inv := &model.Invitation{ + ID: "inv-preview", + TeamID: "team-1", + Email: "invitee@example.com", + Role: "readonly", + InvitedBy: "user-1", + ExpiresAt: time.Now().Add(7 * 24 * time.Hour), + CreatedAt: time.Now(), + } + ms := &mockInvitationStore{ + tokenInv: inv, + teamName: "Alpha Team", + } + h := &InvitationsHandler{Store: ms} + + r := httptest.NewRequest("GET", "/api/v1/invitations/inv_abc", nil) + r = testWithChiParam(r, "token", "inv_abc") + w := httptest.NewRecorder() + + h.Preview(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("Preview: got %d, want %d: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]interface{} + json.NewDecoder(w.Body).Decode(&resp) + if resp["team_name"] != "Alpha Team" { + t.Errorf("team_name = %v, want %q", resp["team_name"], "Alpha Team") + } +} + +func TestPreviewInvitation_GetTeamNameFallsBackToUnknown(t *testing.T) { + inv := &model.Invitation{ + ID: "inv-preview-2", + TeamID: "team-missing", + Email: "invitee@example.com", + Role: "readonly", + InvitedBy: "user-1", + ExpiresAt: time.Now().Add(7 * 24 * time.Hour), + CreatedAt: time.Now(), + } + ms := &mockInvitationStore{ + tokenInv: inv, + teamNameErr: fmt.Errorf("team not found"), + } + h := &InvitationsHandler{Store: ms} + + r := httptest.NewRequest("GET", "/api/v1/invitations/inv_abc", nil) + r = testWithChiParam(r, "token", "inv_abc") + w := httptest.NewRecorder() + + h.Preview(w, r) + + if w.Code != http.StatusOK { + t.Fatalf("Preview: got %d, want %d: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]interface{} + json.NewDecoder(w.Body).Decode(&resp) + if resp["team_name"] != "Unknown" { + t.Errorf("team_name = %v, want %q when GetTeamName fails", resp["team_name"], "Unknown") + } +} diff --git a/internal/handler/oauth.go b/internal/handler/oauth.go index 96feed18..ec2973f9 100644 --- a/internal/handler/oauth.go +++ b/internal/handler/oauth.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net" @@ -15,15 +16,25 @@ import ( "golang.org/x/oauth2" "github.com/scaledtest/scaledtest/internal/auth" - "github.com/scaledtest/scaledtest/internal/db" + "github.com/scaledtest/scaledtest/internal/store" ) +// oauthStore abstracts OAuth persistence operations. +type oauthStore interface { + FindLinkedUser(ctx context.Context, provider, providerID string) (*store.OAuthLinkedUser, error) + FindUserByEmail(ctx context.Context, email string) (*store.OAuthLinkedUser, error) + CreateUser(ctx context.Context, email, displayName string) (userID, role string, err error) + LinkAccount(ctx context.Context, userID, provider, providerID, accessToken, refreshToken string) error + UpdateTokens(ctx context.Context, accessToken, refreshToken, provider, providerID string) error + CreateSession(ctx context.Context, userID, refreshToken, userAgent string, ipAddr net.IP, expiresAt time.Time) error +} + // OAuthHandler handles OAuth 2.0 authentication flows. type OAuthHandler struct { - JWT *auth.JWTManager - DB *db.Pool - OAuth *auth.OAuthConfigs - Secure bool // true if base URL uses HTTPS + JWT *auth.JWTManager + OAuth *auth.OAuthConfigs + OAuthStore oauthStore + Secure bool // true if base URL uses HTTPS } const oauthStateCookie = "oauth_state" @@ -43,7 +54,7 @@ func (h *OAuthHandler) GitHubCallback(w http.ResponseWriter, r *http.Request) { Error(w, http.StatusNotImplemented, "GitHub OAuth is not configured") return } - if h.DB == nil { + if h.OAuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -77,7 +88,7 @@ func (h *OAuthHandler) GoogleCallback(w http.ResponseWriter, r *http.Request) { Error(w, http.StatusNotImplemented, "Google OAuth is not configured") return } - if h.DB == nil { + if h.OAuthStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } @@ -281,64 +292,49 @@ func (h *OAuthHandler) completeOAuth(w http.ResponseWriter, r *http.Request, pro ctx := r.Context() // Check if this OAuth account is already linked - var userID, userEmail, userDisplayName, role string - err := h.DB.QueryRow(ctx, - `SELECT u.id, u.email, u.display_name, u.role - FROM oauth_accounts oa - JOIN users u ON u.id = oa.user_id - WHERE oa.provider = $1 AND oa.provider_id = $2`, - provider, providerID, - ).Scan(&userID, &userEmail, &userDisplayName, &role) - - if err == pgx.ErrNoRows { + user, err := h.OAuthStore.FindLinkedUser(ctx, provider, providerID) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + Error(w, http.StatusInternalServerError, "internal error") + return + } + + if user == nil { // Check if a user with this email already exists (link account) - err = h.DB.QueryRow(ctx, - `SELECT id, email, display_name, role FROM users WHERE email = $1`, email, - ).Scan(&userID, &userEmail, &userDisplayName, &role) + existing, err := h.OAuthStore.FindUserByEmail(ctx, email) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + Error(w, http.StatusInternalServerError, "internal error") + return + } - if err == pgx.ErrNoRows { + if existing == nil { // Create new user - err = h.DB.QueryRow(ctx, - `INSERT INTO users (email, password_hash, display_name) - VALUES ($1, '', $2) - RETURNING id, role`, - email, displayName, - ).Scan(&userID, &role) + userID, role, err := h.OAuthStore.CreateUser(ctx, email, displayName) if err != nil { Error(w, http.StatusInternalServerError, "failed to create user") return } - userEmail = email - userDisplayName = displayName - } else if err != nil { - Error(w, http.StatusInternalServerError, "internal error") - return + user = &store.OAuthLinkedUser{ + ID: userID, + Email: email, + DisplayName: displayName, + Role: role, + } + } else { + user = existing } // Link OAuth account - _, err = h.DB.Exec(ctx, - `INSERT INTO oauth_accounts (user_id, provider, provider_id, access_token, refresh_token) - VALUES ($1, $2, $3, $4, $5)`, - userID, provider, providerID, token.AccessToken, token.RefreshToken, - ) - if err != nil { + if err := h.OAuthStore.LinkAccount(ctx, user.ID, provider, providerID, token.AccessToken, token.RefreshToken); err != nil { Error(w, http.StatusInternalServerError, "failed to link OAuth account") return } - } else if err != nil { - Error(w, http.StatusInternalServerError, "internal error") - return } else { // Update stored tokens - _, _ = h.DB.Exec(ctx, - `UPDATE oauth_accounts SET access_token = $1, refresh_token = $2 - WHERE provider = $3 AND provider_id = $4`, - token.AccessToken, token.RefreshToken, provider, providerID, - ) + _ = h.OAuthStore.UpdateTokens(ctx, token.AccessToken, token.RefreshToken, provider, providerID) } // Issue JWT tokens - pair, err := h.JWT.GenerateTokenPair(userID, userEmail, role, "") + pair, err := h.JWT.GenerateTokenPair(user.ID, user.Email, user.Role, "") if err != nil { Error(w, http.StatusInternalServerError, "failed to generate tokens") return @@ -352,12 +348,7 @@ func (h *OAuthHandler) completeOAuth(w http.ResponseWriter, r *http.Request, pro } expiresAt := time.Now().Add(h.JWT.RefreshDuration()) - _, err = h.DB.Exec(ctx, - `INSERT INTO sessions (user_id, refresh_token, user_agent, ip_address, expires_at) - VALUES ($1, $2, $3, $4, $5)`, - userID, pair.RefreshToken, r.UserAgent(), ipAddr, expiresAt, - ) - if err != nil { + if err := h.OAuthStore.CreateSession(ctx, user.ID, pair.RefreshToken, r.UserAgent(), ipAddr, expiresAt); err != nil { Error(w, http.StatusInternalServerError, "failed to create session") return } @@ -366,10 +357,10 @@ func (h *OAuthHandler) completeOAuth(w http.ResponseWriter, r *http.Request, pro JSON(w, http.StatusOK, AuthResponse{ User: UserResponse{ - ID: userID, - Email: userEmail, - DisplayName: userDisplayName, - Role: role, + ID: user.ID, + Email: user.Email, + DisplayName: user.DisplayName, + Role: user.Role, }, AccessToken: pair.AccessToken, ExpiresAt: pair.ExpiresAt, diff --git a/internal/handler/oauth_test.go b/internal/handler/oauth_test.go index 52947041..d45e06a3 100644 --- a/internal/handler/oauth_test.go +++ b/internal/handler/oauth_test.go @@ -13,7 +13,7 @@ import ( func newTestOAuthHandler(oauthCfgs *auth.OAuthConfigs) *OAuthHandler { jwt := auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour) - return &OAuthHandler{JWT: jwt, DB: nil, OAuth: oauthCfgs, Secure: false} + return &OAuthHandler{JWT: jwt, OAuthStore: nil, OAuth: oauthCfgs, Secure: false} } func TestGitHubLogin_NotConfigured(t *testing.T) { @@ -145,7 +145,7 @@ func TestGoogleCallback_NotConfigured(t *testing.T) { } } -func TestGitHubCallback_NoDB(t *testing.T) { +func TestGitHubCallback_NoStore(t *testing.T) { cfg := &auth.OAuthConfigs{ GitHub: &oauth2.Config{ ClientID: "test-id", @@ -156,14 +156,14 @@ func TestGitHubCallback_NoDB(t *testing.T) { }, }, } - h := newTestOAuthHandler(cfg) // DB is nil + h := newTestOAuthHandler(cfg) // OAuthStore is nil req := httptest.NewRequest("GET", "/auth/github/callback?code=abc&state=xyz", nil) w := httptest.NewRecorder() h.GitHubCallback(w, req) if w.Code != http.StatusServiceUnavailable { - t.Errorf("GitHubCallback no DB: got %d, want %d", w.Code, http.StatusServiceUnavailable) + t.Errorf("GitHubCallback no store: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -178,20 +178,20 @@ func TestGitHubCallback_MissingState(t *testing.T) { }, }, } - // Need a non-nil DB to get past the DB check + // Need a non-nil OAuthStore to get past the store check h := &OAuthHandler{ - JWT: auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour), - DB: nil, // We'll test the state check, but DB nil check comes first - OAuth: cfg, + JWT: auth.NewJWTManager(testSecret, 15*time.Minute, 7*24*time.Hour), + OAuthStore: nil, // We'll test the state check, but store nil check comes first + OAuth: cfg, } - // With nil DB, we'll get 503 before state check. That's OK for this path. + // With nil OAuthStore, we'll get 503 before state check. That's OK for this path. // The state validation test is implicit in the callback_NoDB test. req := httptest.NewRequest("GET", "/auth/github/callback?code=abc", nil) w := httptest.NewRecorder() h.GitHubCallback(w, req) - // Should get 503 (no DB) since DB check happens before state check + // Should get 503 (no store) since store check happens before state check if w.Code != http.StatusServiceUnavailable { t.Errorf("GitHubCallback missing state: got %d, want %d", w.Code, http.StatusServiceUnavailable) } @@ -208,7 +208,7 @@ func TestGitHubCallback_ProviderError(t *testing.T) { }, }, } - // DB nil → 503 before provider error check. This tests the not-configured path. + // OAuthStore nil → 503 before provider error check. This tests the not-configured path. h := newTestOAuthHandler(cfg) req := httptest.NewRequest("GET", "/auth/github/callback?error=access_denied&state=xyz", nil) w := httptest.NewRecorder() diff --git a/internal/handler/quality_gates.go b/internal/handler/quality_gates.go index 4e7190d3..695ab7a6 100644 --- a/internal/handler/quality_gates.go +++ b/internal/handler/quality_gates.go @@ -10,7 +10,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/scaledtest/scaledtest/internal/auth" - "github.com/scaledtest/scaledtest/internal/db" + "github.com/scaledtest/scaledtest/internal/model" "github.com/scaledtest/scaledtest/internal/quality" "github.com/scaledtest/scaledtest/internal/sanitize" @@ -54,9 +54,9 @@ type qualityGateStore interface { // QualityGatesHandler handles quality gate endpoints. type QualityGatesHandler struct { - Store qualityGateStore - DB *db.Pool - AuditStore auditLogger + Store qualityGateStore + ReportStore reportsStore + AuditStore auditLogger } // CreateQualityGateRequest is the request body for creating a quality gate. @@ -121,6 +121,17 @@ func teamIDFromURL(w http.ResponseWriter, r *http.Request, claims *auth.Claims) return teamID, true } +// gateIDFromURL extracts the gateID URL parameter. +// Returns the gateID or writes an error. +func gateIDFromURL(w http.ResponseWriter, r *http.Request) (string, bool) { + gateID := chi.URLParam(r, "gateID") + if gateID == "" { + Error(w, http.StatusBadRequest, "missing gate ID") + return "", false + } + return gateID, true +} + // List handles GET /api/v1/teams/:teamID/quality-gates. func (h *QualityGatesHandler) List(w http.ResponseWriter, r *http.Request) { claims := auth.GetClaims(r.Context()) @@ -193,7 +204,6 @@ func (h *QualityGatesHandler) Create(w http.ResponseWriter, r *http.Request) { return } - // Sanitize user-provided strings req.Name = sanitize.String(req.Name) req.Description = sanitize.String(req.Description) @@ -229,9 +239,8 @@ func (h *QualityGatesHandler) Get(w http.ResponseWriter, r *http.Request) { return } - gateID := chi.URLParam(r, "gateID") - if gateID == "" { - Error(w, http.StatusBadRequest, "missing gate ID") + gateID, ok := gateIDFromURL(w, r) + if !ok { return } @@ -266,9 +275,8 @@ func (h *QualityGatesHandler) Update(w http.ResponseWriter, r *http.Request) { return } - gateID := chi.URLParam(r, "gateID") - if gateID == "" { - Error(w, http.StatusBadRequest, "missing gate ID") + gateID, ok := gateIDFromURL(w, r) + if !ok { return } @@ -288,7 +296,6 @@ func (h *QualityGatesHandler) Update(w http.ResponseWriter, r *http.Request) { return } - // Sanitize user-provided strings req.Name = sanitize.String(req.Name) req.Description = sanitize.String(req.Description) @@ -333,9 +340,8 @@ func (h *QualityGatesHandler) Delete(w http.ResponseWriter, r *http.Request) { return } - gateID := chi.URLParam(r, "gateID") - if gateID == "" { - Error(w, http.StatusBadRequest, "missing gate ID") + gateID, ok := gateIDFromURL(w, r) + if !ok { return } @@ -379,13 +385,12 @@ func (h *QualityGatesHandler) Evaluate(w http.ResponseWriter, r *http.Request) { return } - gateID := chi.URLParam(r, "gateID") - if gateID == "" { - Error(w, http.StatusBadRequest, "missing gate ID") + gateID, ok := gateIDFromURL(w, r) + if !ok { return } - if h.Store == nil || h.DB == nil { + if h.Store == nil || h.ReportStore == nil { Error(w, http.StatusNotImplemented, "evaluate requires database connection") return } @@ -400,18 +405,13 @@ func (h *QualityGatesHandler) Evaluate(w http.ResponseWriter, r *http.Request) { return } - // Get the gate to access its rules gate, err := h.Store.Get(r.Context(), teamID, gateID) if err != nil { Error(w, http.StatusNotFound, "quality gate not found") return } - // Load report summary from DB - var summaryJSON json.RawMessage - err = h.DB.QueryRow(r.Context(), - `SELECT summary FROM test_reports WHERE id = $1 AND team_id = $2`, - req.ReportID, teamID).Scan(&summaryJSON) + rpt, testResults, err := h.ReportStore.GetReportAndResults(r.Context(), req.ReportID, teamID) if err != nil { Error(w, http.StatusNotFound, "report not found") return @@ -423,60 +423,22 @@ func (h *QualityGatesHandler) Evaluate(w http.ResponseWriter, r *http.Request) { Failed int `json:"failed"` Skipped int `json:"skipped"` } - if err := json.Unmarshal(summaryJSON, &summary); err != nil { + if err := json.Unmarshal(rpt.Summary, &summary); err != nil { Error(w, http.StatusInternalServerError, "failed to parse report summary") return } - // Load test results for duration and flaky data - rows, err := h.DB.Query(r.Context(), - `SELECT name, status, duration_ms, flaky, suite, file_path - FROM test_results WHERE report_id = $1 AND team_id = $2`, - req.ReportID, teamID) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to query test results") - return - } - defer rows.Close() - var totalDurationMs int64 currentFailed := make(map[string]bool) - var flakyTests []struct { - name, suite, filePath string - } - - for rows.Next() { - var name, status string - var suite, filePath *string - var durationMs int64 - var flaky bool - if err := rows.Scan(&name, &status, &durationMs, &flaky, &suite, &filePath); err != nil { - Error(w, http.StatusInternalServerError, "failed to scan test result") - return - } - totalDurationMs += durationMs - if status == "failed" { - currentFailed[name] = true - } - if flaky { - s, fp := "", "" - if suite != nil { - s = *suite - } - if filePath != nil { - fp = *filePath - } - flakyTests = append(flakyTests, struct { - name, suite, filePath string - }{name, s, fp}) + + for _, res := range testResults { + totalDurationMs += res.DurationMs + if res.Status == "failed" { + currentFailed[res.Name] = true } } - if err := rows.Err(); err != nil { - Error(w, http.StatusInternalServerError, "failed to iterate test results") - return - } - previousFailed, err := fetchPreviousFailedTests(r.Context(), h.DB, teamID, req.ReportID) + previousFailed, err := h.ReportStore.GetPreviousFailedTests(r.Context(), teamID, req.ReportID) if err != nil { Error(w, http.StatusInternalServerError, "failed to fetch previous failures") return @@ -499,7 +461,6 @@ func (h *QualityGatesHandler) Evaluate(w http.ResponseWriter, r *http.Request) { return } - // Store evaluation result detailsJSON, _ := json.Marshal(evalResult.Results) eval, err := h.Store.CreateEvaluation(r.Context(), gateID, req.ReportID, evalResult.Passed, detailsJSON) if err != nil { @@ -507,7 +468,6 @@ func (h *QualityGatesHandler) Evaluate(w http.ResponseWriter, r *http.Request) { return } - // Build response rules := make([]QualityGateRuleResult, len(evalResult.Results)) for i, rr := range evalResult.Results { rules[i] = QualityGateRuleResult{ @@ -541,9 +501,8 @@ func (h *QualityGatesHandler) ListEvaluations(w http.ResponseWriter, r *http.Req return } - gateID := chi.URLParam(r, "gateID") - if gateID == "" { - Error(w, http.StatusBadRequest, "missing gate ID") + gateID, ok := gateIDFromURL(w, r) + if !ok { return } diff --git a/internal/handler/quality_gates_test.go b/internal/handler/quality_gates_test.go index 459df289..d2d35752 100644 --- a/internal/handler/quality_gates_test.go +++ b/internal/handler/quality_gates_test.go @@ -332,7 +332,7 @@ func TestQualityGatesDeleteReadonlyForbidden(t *testing.T) { } func TestQualityGatesEvaluateWithoutDB(t *testing.T) { - h := &QualityGatesHandler{Store: nil, DB: nil} + h := &QualityGatesHandler{Store: nil, ReportStore: nil} body := `{"report_id":"report-123"}` req := httptest.NewRequest("POST", "/api/v1/teams/team-1/quality-gates/gate-1/evaluate", strings.NewReader(body)) @@ -350,7 +350,7 @@ func TestQualityGatesEvaluateWithoutDB(t *testing.T) { } func TestQualityGatesEvaluateMissingReportID(t *testing.T) { - h := &QualityGatesHandler{Store: nil, DB: nil} + h := &QualityGatesHandler{Store: nil, ReportStore: nil} body := `{}` req := httptest.NewRequest("POST", "/api/v1/teams/team-1/quality-gates/gate-1/evaluate", strings.NewReader(body)) @@ -370,7 +370,7 @@ func TestQualityGatesEvaluateMissingReportID(t *testing.T) { } func TestQualityGatesEvaluateUnauthorized(t *testing.T) { - h := &QualityGatesHandler{Store: nil, DB: nil} + h := &QualityGatesHandler{Store: nil, ReportStore: nil} body := `{"report_id":"report-123"}` req := httptest.NewRequest("POST", "/api/v1/teams/team-1/quality-gates/gate-1/evaluate", strings.NewReader(body)) @@ -387,7 +387,7 @@ func TestQualityGatesEvaluateUnauthorized(t *testing.T) { } func TestQualityGatesEvaluateMissingGateID(t *testing.T) { - h := &QualityGatesHandler{Store: nil, DB: nil} + h := &QualityGatesHandler{Store: nil, ReportStore: nil} body := `{"report_id":"report-123"}` req := httptest.NewRequest("POST", "/api/v1/teams/team-1/quality-gates//evaluate", strings.NewReader(body)) @@ -465,7 +465,7 @@ func TestValidateRules(t *testing.T) { // mockQGStore implements qualityGateStore for audit logging tests. type mockQGStore struct { - gate *model.QualityGate + gate *model.QualityGate delErr error } diff --git a/internal/handler/reports.go b/internal/handler/reports.go index 7e75a125..3f67b813 100644 --- a/internal/handler/reports.go +++ b/internal/handler/reports.go @@ -18,7 +18,7 @@ import ( "github.com/scaledtest/scaledtest/internal/analytics" "github.com/scaledtest/scaledtest/internal/auth" "github.com/scaledtest/scaledtest/internal/ctrf" - "github.com/scaledtest/scaledtest/internal/db" + "github.com/scaledtest/scaledtest/internal/model" "github.com/scaledtest/scaledtest/internal/quality" "github.com/scaledtest/scaledtest/internal/store" @@ -49,23 +49,15 @@ type githubStatusPoster interface { // ReportsHandler handles CTRF report endpoints. type ReportsHandler struct { - DB *db.Pool + ReportStore reportsStore AuditStore *store.AuditStore QualityGateStore qualityGateEvaluator Webhooks *webhook.Notifier - GitHubStatusPoster githubStatusPoster // nil when GitHub integration is disabled - BaseURL string // used to construct target URLs in GitHub statuses - // TriageStore provides access to persisted triage results for read and retry. - // When nil, triage endpoints return 503. - TriageStore triageAccessor - // TriageEnqueuer schedules background LLM triage for each ingested report. - // When nil, triage is disabled (e.g. no LLM credentials configured). - TriageEnqueuer triage.Enqueuer - // AllowBackdate permits callers to supply a ?created_at= query - // parameter to override the report ingestion timestamp. This must only be - // enabled in controlled test environments (e.g. when ST_DISABLE_RATE_LIMIT - // is true) — never in production. - AllowBackdate bool + GitHubStatusPoster githubStatusPoster + BaseURL string + TriageStore triageAccessor + TriageEnqueuer triage.Enqueuer + AllowBackdate bool } // List handles GET /api/v1/reports. @@ -76,8 +68,8 @@ func (h *ReportsHandler) List(w http.ResponseWriter, r *http.Request) { return } - // Validate date filters before the DB check so malformed params return 400 - // rather than falling through to a DB error. + // Validate date filters before the store check so malformed params return 400 + // rather than falling through to a store error. var sinceTime, untilTime time.Time var hasSince, hasUntil bool @@ -100,69 +92,33 @@ func (h *ReportsHandler) List(w http.ResponseWriter, r *http.Request) { hasUntil = true } - if h.DB == nil { + if h.ReportStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } limit, offset := parsePagination(r) - // Build WHERE clause with optional date filters - whereClause := ` WHERE team_id = $1` - args := []interface{}{claims.TeamID} - argIdx := 2 - + var sincePtr, untilPtr *time.Time if hasSince { - whereClause += ` AND created_at >= $` + strconv.Itoa(argIdx) - args = append(args, sinceTime) - argIdx++ + sincePtr = &sinceTime } if hasUntil { - whereClause += ` AND created_at <= $` + strconv.Itoa(argIdx) - args = append(args, untilTime) - argIdx++ - } - - // Count query uses the same WHERE clause (including since/until filters) - countQuery := `SELECT COUNT(*) FROM test_reports` + whereClause - var total int - if err := h.DB.QueryRow(r.Context(), countQuery, args...).Scan(&total); err != nil { - Error(w, http.StatusInternalServerError, "failed to count reports") - return - } - - // Data query - query := `SELECT id, team_id, execution_id, tool_name, tool_version, environment, summary, created_at - FROM test_reports` + whereClause + - ` ORDER BY created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1) - dataArgs := append(args, limit, offset) - - rows, err := h.DB.Query(r.Context(), query, dataArgs...) + untilPtr = &untilTime + } + reports, total, err := h.ReportStore.List(r.Context(), store.ReportListFilter{ + TeamID: claims.TeamID, + Since: sincePtr, + Until: untilPtr, + Limit: limit, + Offset: offset, + }) if err != nil { Error(w, http.StatusInternalServerError, "failed to query reports") return } - defer rows.Close() - - flatReports := make([]map[string]interface{}, 0) - for rows.Next() { - var rpt model.TestReport - if err := rows.Scan( - &rpt.ID, &rpt.TeamID, &rpt.ExecutionID, &rpt.ToolName, - &rpt.ToolVersion, &rpt.Environment, &rpt.Summary, &rpt.CreatedAt, - ); err != nil { - Error(w, http.StatusInternalServerError, "failed to scan report") - return - } - flatReports = append(flatReports, flattenReportForList(rpt)) - } - if err := rows.Err(); err != nil { - Error(w, http.StatusInternalServerError, "failed to iterate reports") - return - } - JSON(w, http.StatusOK, map[string]interface{}{ - "reports": flatReports, + "reports": reports, "total": total, }) } @@ -192,45 +148,27 @@ func (h *ReportsHandler) Create(w http.ResponseWriter, r *http.Request) { return } - // Sanitize all user-controlled string fields to prevent stored XSS ctrf.Sanitize(report) executionID := r.URL.Query().Get("execution_id") triageGitHubStatus := r.URL.Query().Get("triage_github_status") == "true" - if h.DB == nil { - // Fallback for no-DB mode: accept but don't persist - resp := map[string]interface{}{ - "message": "report accepted", - "tool": report.Results.Tool.Name, - "tests": report.Results.Summary.Tests, - } - if executionID != "" { - resp["execution_id"] = executionID - } - if triageGitHubStatus { - resp["triage_github_status"] = true - } - JSON(w, http.StatusCreated, resp) - h.maybePostGitHubStatus(r, report.Results.Summary, "", executionID) + if h.ReportStore == nil { + Error(w, http.StatusServiceUnavailable, "database not configured") return } reportID := uuid.New().String() now := h.resolveReportTime(r) - // Validate execution_id as UUID and verify team ownership if provided var execIDPtr *string if executionID != "" { if _, err := uuid.Parse(executionID); err != nil { Error(w, http.StatusBadRequest, "invalid execution_id: must be a valid UUID") return } - // Verify execution exists and belongs to this team var exists bool - err := h.DB.QueryRow(r.Context(), - `SELECT EXISTS(SELECT 1 FROM test_executions WHERE id = $1 AND team_id = $2)`, - executionID, claims.TeamID).Scan(&exists) + exists, err = h.ReportStore.ExecutionExists(r.Context(), executionID, claims.TeamID) if err != nil { Error(w, http.StatusInternalServerError, "failed to verify execution") return @@ -242,77 +180,36 @@ func (h *ReportsHandler) Create(w http.ResponseWriter, r *http.Request) { execIDPtr = &executionID } - // Build summary JSON summaryJSON, err := ctrf.SummaryJSON(report.Results.Summary) if err != nil { Error(w, http.StatusInternalServerError, "failed to marshal summary") return } - // Store raw CTRF for archival - rawJSON := json.RawMessage(body) - - // Use a transaction for atomic report ingestion - tx, err := h.DB.Begin(r.Context()) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to begin transaction") - return - } - defer tx.Rollback(r.Context()) - - // Insert report - _, err = tx.Exec(r.Context(), - `INSERT INTO test_reports (id, team_id, execution_id, tool_name, tool_version, environment, summary, raw, created_at, triage_github_status) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`, - reportID, claims.TeamID, execIDPtr, - report.Results.Tool.Name, report.Results.Tool.Version, - report.Results.Environment, summaryJSON, rawJSON, now, - triageGitHubStatus) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to store report") - return - } - - // Normalize and insert individual test results results := ctrf.Normalize(report, reportID, claims.TeamID) - for _, res := range results { - resID := uuid.New().String() - _, err = tx.Exec(r.Context(), - `INSERT INTO test_results (id, report_id, team_id, name, status, duration_ms, message, trace, file_path, suite, tags, retry, flaky, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)`, - resID, res.ReportID, res.TeamID, res.Name, res.Status, - res.DurationMs, nullString(res.Message), nullString(res.Trace), - nullString(res.FilePath), nullString(res.Suite), - res.Tags, res.Retry, res.Flaky, now) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to store test result") - return - } - } - // If linked to an execution, update execution with report_id - if execIDPtr != nil { - tag, err := tx.Exec(r.Context(), - `UPDATE test_executions SET report_id = $1, updated_at = $2 - WHERE id = $3 AND team_id = $4`, - reportID, now, executionID, claims.TeamID) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to link report to execution") - return - } - if tag.RowsAffected() == 0 { + rawJSON := json.RawMessage(body) + params := store.CreateReportParams{ + ID: reportID, + TeamID: claims.TeamID, + ExecutionID: execIDPtr, + ToolName: report.Results.Tool.Name, + ToolVersion: report.Results.Tool.Version, + Environment: report.Results.Environment, + Summary: summaryJSON, + Raw: rawJSON, + CreatedAt: now, + TriageGitHubStatus: triageGitHubStatus, + } + if err := h.ReportStore.CreateWithResults(r.Context(), params, results); err != nil { + if err == pgx.ErrNoRows { Error(w, http.StatusBadRequest, "execution not found or not in team") return } - } - - if err := tx.Commit(r.Context()); err != nil { - Error(w, http.StatusInternalServerError, "failed to commit report") + Error(w, http.StatusInternalServerError, "failed to store report") return } - // Enqueue async triage — non-blocking, best-effort. Must be called after - // the transaction commits so the triage job can read the persisted rows. if h.TriageEnqueuer != nil { h.TriageEnqueuer.Enqueue(claims.TeamID, reportID) } @@ -331,7 +228,6 @@ func (h *ReportsHandler) Create(w http.ResponseWriter, r *http.Request) { resp["triage_github_status"] = true } - // Evaluate quality gates for this team if h.QualityGateStore != nil { gateResult := h.evaluateQualityGates(r, claims.TeamID, reportID, report, results) if gateResult != nil { @@ -358,7 +254,6 @@ func (h *ReportsHandler) Create(w http.ResponseWriter, r *http.Request) { }) } - // Fire webhook: report.submitted h.Webhooks.Notify(claims.TeamID, webhook.EventReportSubmitted, map[string]interface{}{ "report_id": reportID, "tool": report.Results.Tool.Name, @@ -368,7 +263,6 @@ func (h *ReportsHandler) Create(w http.ResponseWriter, r *http.Request) { "failed": report.Results.Summary.Failed, }) - // Fire webhook: gate.failed if any quality gate failed if gateResult, ok := resp["qualityGate"].(*QualityGateResponse); ok && gateResult != nil && !gateResult.Passed { h.Webhooks.Notify(claims.TeamID, webhook.EventGateFailed, map[string]interface{}{ "report_id": reportID, @@ -394,20 +288,12 @@ func (h *ReportsHandler) Get(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ReportStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - var rpt model.TestReport - err := h.DB.QueryRow(r.Context(), - `SELECT id, team_id, execution_id, tool_name, tool_version, environment, summary, created_at - FROM test_reports - WHERE id = $1 AND team_id = $2`, - reportID, claims.TeamID).Scan( - &rpt.ID, &rpt.TeamID, &rpt.ExecutionID, &rpt.ToolName, - &rpt.ToolVersion, &rpt.Environment, &rpt.Summary, &rpt.CreatedAt, - ) + rpt, err := h.ReportStore.Get(r.Context(), reportID, claims.TeamID) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "report not found") return @@ -416,8 +302,7 @@ func (h *ReportsHandler) Get(w http.ResponseWriter, r *http.Request) { Error(w, http.StatusInternalServerError, "failed to get report") return } - - JSON(w, http.StatusOK, buildGetReportResponse(rpt)) + JSON(w, http.StatusOK, buildGetReportResponse(*rpt)) } // Delete handles DELETE /api/v1/reports/{reportID}. @@ -434,19 +319,17 @@ func (h *ReportsHandler) Delete(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ReportStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - tag, err := h.DB.Exec(r.Context(), - `DELETE FROM test_reports WHERE id = $1 AND team_id = $2`, - reportID, claims.TeamID) + rowsAffected, err := h.ReportStore.Delete(r.Context(), reportID, claims.TeamID) if err != nil { Error(w, http.StatusInternalServerError, "failed to delete report") return } - if tag.RowsAffected() == 0 { + if rowsAffected == 0 { Error(w, http.StatusNotFound, "report not found") return } @@ -489,25 +372,16 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.ReportStore == nil { Error(w, http.StatusServiceUnavailable, "database not configured") return } - // Fetch both report metadata in parallel (sequential for simplicity, both must belong to team) - fetchReport := func(id string) (*model.TestReport, error) { - var rpt model.TestReport - err := h.DB.QueryRow(r.Context(), - `SELECT id, team_id, execution_id, tool_name, tool_version, summary, created_at - FROM test_reports WHERE id = $1 AND team_id = $2`, - id, claims.TeamID).Scan( - &rpt.ID, &rpt.TeamID, &rpt.ExecutionID, &rpt.ToolName, - &rpt.ToolVersion, &rpt.Summary, &rpt.CreatedAt, - ) - return &rpt, err - } + var baseReport, headReport *model.TestReport + var baseResults, headResults map[string]*model.TestResult - baseReport, err := fetchReport(baseID) + var err error + baseReport, baseResults, err = h.ReportStore.GetReportAndResults(r.Context(), baseID, claims.TeamID) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "base report not found") return @@ -516,8 +390,7 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { Error(w, http.StatusInternalServerError, "failed to fetch base report") return } - - headReport, err := fetchReport(headID) + headReport, headResults, err = h.ReportStore.GetReportAndResults(r.Context(), headID, claims.TeamID) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "head report not found") return @@ -527,46 +400,6 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { return } - // Fetch test results for both reports - fetchResults := func(reportID string) (map[string]*model.TestResult, error) { - rows, err := h.DB.Query(r.Context(), - `SELECT id, report_id, team_id, name, status, duration_ms, - COALESCE(message, ''), COALESCE(trace, ''), COALESCE(file_path, ''), COALESCE(suite, ''), - tags, retry, flaky, created_at - FROM test_results WHERE report_id = $1 AND team_id = $2`, - reportID, claims.TeamID) - if err != nil { - return nil, err - } - defer rows.Close() - - results := make(map[string]*model.TestResult) - for rows.Next() { - var res model.TestResult - if err := rows.Scan( - &res.ID, &res.ReportID, &res.TeamID, &res.Name, &res.Status, - &res.DurationMs, &res.Message, &res.Trace, &res.FilePath, - &res.Suite, &res.Tags, &res.Retry, &res.Flaky, &res.CreatedAt, - ); err != nil { - return nil, err - } - results[res.Name] = &res - } - return results, rows.Err() - } - - baseResults, err := fetchResults(baseID) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to fetch base test results") - return - } - - headResults, err := fetchResults(headID) - if err != nil { - Error(w, http.StatusInternalServerError, "failed to fetch head test results") - return - } - // Compute diff type TestDiff struct { Name string `json:"name"` @@ -585,11 +418,9 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { var fixed []TestDiff var durationRegressions []TestDiff - // Tests in head: compare against base for name, headRes := range headResults { baseRes, existed := baseResults[name] if !existed { - // New test — only flag if it failed if headRes.Status == "failed" { newFailures = append(newFailures, TestDiff{ Name: name, @@ -603,7 +434,6 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { continue } - // Status changes if baseRes.Status != "failed" && headRes.Status == "failed" { newFailures = append(newFailures, TestDiff{ Name: name, @@ -627,7 +457,6 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { }) } - // Duration regression: >20% slower AND at least 100ms longer if baseRes.DurationMs > 0 { delta := headRes.DurationMs - baseRes.DurationMs pct := float64(delta) / float64(baseRes.DurationMs) * 100 @@ -647,9 +476,6 @@ func (h *ReportsHandler) Compare(w http.ResponseWriter, r *http.Request) { } } - // Tests that existed in base but are gone from head (treat as removed, not failure) - // No action needed per spec — just track new failures / fixed. - type DiffSummary struct { BaseTests int `json:"base_tests"` HeadTests int `json:"head_tests"` @@ -716,7 +542,6 @@ func (h *ReportsHandler) GetTriage(w http.ResponseWriter, r *http.Request) { return } - // Index classifications by cluster ID for O(1) look-up when building output. classByCluster := make(map[string][]map[string]string) for _, c := range classifications { key := "" @@ -795,13 +620,11 @@ func (h *ReportsHandler) RetryTriage(w http.ResponseWriter, r *http.Request) { return } - // Prevent destructive reset when no job can be enqueued to regenerate the data. if h.TriageEnqueuer == nil { Error(w, http.StatusServiceUnavailable, "triage not available") return } - // Reset from complete or failed back to pending. resetResult, err := h.TriageStore.ForceReset(r.Context(), claims.TeamID, reportID) if err != nil { log.Error().Err(err).Str("report_id", reportID).Msg("failed to reset triage for retry") @@ -931,14 +754,6 @@ func (h *ReportsHandler) resolveReportTime(r *http.Request) time.Time { return time.Now() } -// nullString returns a *string that is nil for empty strings. -func nullString(s string) *string { - if s == "" { - return nil - } - return &s -} - // parsePagination extracts limit and offset from query parameters. func parsePagination(r *http.Request) (int, int) { limit := 50 @@ -969,8 +784,8 @@ type QualityGateRuleResult struct { // QualityGateResponse is the quality gate section of the report submission response. type QualityGateResponse struct { - Passed bool `json:"passed"` - Gates []QualityGateDetail `json:"gates"` + Passed bool `json:"passed"` + Gates []QualityGateDetail `json:"gates"` } // QualityGateDetail is a single gate's evaluation in the response. @@ -999,7 +814,7 @@ func (h *ReportsHandler) evaluateQualityGates( return nil } - previousFailed, prevErr := fetchPreviousFailedTests(r.Context(), h.DB, teamID, reportID) + previousFailed, prevErr := h.getPreviousFailedTests(r.Context(), teamID, reportID) if prevErr != nil { log.Warn().Err(prevErr).Str("team_id", teamID).Str("report_id", reportID). Msg("failed to fetch previous failures for quality gate evaluation; skipping gate evaluation") @@ -1015,7 +830,6 @@ func (h *ReportsHandler) evaluateQualityGates( continue } - // Store evaluation in DB detailsJSON, _ := json.Marshal(evalResult.Results) _, storeErr := h.QualityGateStore.CreateEvaluation( r.Context(), gate.ID, reportID, evalResult.Passed, detailsJSON, @@ -1024,7 +838,6 @@ func (h *ReportsHandler) evaluateQualityGates( log.Error().Err(storeErr).Str("gate_id", gate.ID).Msg("failed to store gate evaluation") } - // Build response detail rules := make([]QualityGateRuleResult, len(evalResult.Results)) for i, rr := range evalResult.Results { rules[i] = QualityGateRuleResult{ @@ -1051,52 +864,11 @@ func (h *ReportsHandler) evaluateQualityGates( return gateResp } -// fetchPreviousFailedTests returns the set of failed test names from the most -// recent prior report for the given team (excluding currentReportID). Returns -// (nil, nil) if no prior report exists or the prior report had no failures. -// Returns a non-nil error on DB errors so callers can distinguish transient -// failures from the legitimate "no baseline" case. -func fetchPreviousFailedTests(ctx context.Context, pool *db.Pool, teamID, currentReportID string) (map[string]bool, error) { - if pool == nil { - return nil, nil - } - - var prevReportID string - err := pool.QueryRow(ctx, - `SELECT id FROM test_reports WHERE team_id = $1 AND id != $2 ORDER BY created_at DESC LIMIT 1`, - teamID, currentReportID, - ).Scan(&prevReportID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return nil, nil // no prior report — not an error - } - return nil, fmt.Errorf("fetch previous report: %w", err) - } - - rows, err := pool.Query(ctx, - `SELECT name FROM test_results WHERE report_id = $1 AND status = 'failed'`, - prevReportID, - ) - if err != nil { - return nil, fmt.Errorf("fetch previous failures: %w", err) - } - defer rows.Close() - - failed := make(map[string]bool) - for rows.Next() { - var name string - if err := rows.Scan(&name); err != nil { - return nil, fmt.Errorf("scan previous failure: %w", err) - } - failed[name] = true - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("iterate previous failures: %w", err) - } - if len(failed) == 0 { +func (h *ReportsHandler) getPreviousFailedTests(ctx context.Context, teamID, currentReportID string) (map[string]bool, error) { + if h.ReportStore == nil { return nil, nil } - return failed, nil + return h.ReportStore.GetPreviousFailedTests(ctx, teamID, currentReportID) } // buildReportData constructs quality.ReportData from a CTRF report, its @@ -1178,4 +950,3 @@ func (h *ReportsHandler) maybePostGitHubStatus(r *http.Request, summary ctrf.Sum } }() } - diff --git a/internal/handler/reports_bench_test.go b/internal/handler/reports_bench_test.go new file mode 100644 index 00000000..4a6bc14f --- /dev/null +++ b/internal/handler/reports_bench_test.go @@ -0,0 +1,135 @@ +package handler + +import ( + "testing" + + "github.com/scaledtest/scaledtest/internal/ctrf" + "github.com/scaledtest/scaledtest/internal/model" +) + +// BenchmarkNormalize_100Results measures CTRF normalization for 100 test results. +// This benchmarks the in-memory processing that happens before bulk-insert, +// confirming the store layer avoids N+1 round-trips. +func BenchmarkNormalize_100Results(b *testing.B) { + tests := make([]ctrf.Test, 100) + for i := range tests { + tests[i] = ctrf.Test{ + Name: "benchmark-test-name-that-is-reasonably-long", + Status: "passed", + Duration: 100, + Message: "benchmark message", + Trace: "at benchmark.go:42", + FilePath: "benchmark/path/to/test.go", + Suite: "BenchmarkSuite", + Tags: []string{"smoke", "benchmark"}, + Retry: 0, + Flaky: false, + } + } + report := &ctrf.Report{ + Results: ctrf.Results{ + Tool: ctrf.Tool{Name: "bench-tool", Version: "1.0"}, + Summary: ctrf.Summary{Tests: 100, Passed: 90, Failed: 5, Skipped: 5, Pending: 0, Other: 0}, + Tests: tests, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + results := ctrf.Normalize(report, "bench-report-id", "bench-team-id") + _ = results + } +} + +// BenchmarkNormalize_1000Results measures CTRF normalization for 1000 test results. +func BenchmarkNormalize_1000Results(b *testing.B) { + tests := make([]ctrf.Test, 1000) + for i := range tests { + status := "passed" + if i%20 == 0 { + status = "failed" + } + if i%50 == 0 { + status = "skipped" + } + tests[i] = ctrf.Test{ + Name: "benchmark-test-name-that-is-reasonably-long", + Status: status, + Duration: float64(i * 10), + Message: "benchmark message", + Trace: "at benchmark.go:42", + FilePath: "benchmark/path/to/test.go", + Suite: "BenchmarkSuite", + Tags: []string{"smoke", "benchmark"}, + Retry: 0, + Flaky: false, + } + } + report := &ctrf.Report{ + Results: ctrf.Results{ + Tool: ctrf.Tool{Name: "bench-tool", Version: "1.0"}, + Summary: ctrf.Summary{Tests: 1000, Passed: 900, Failed: 50, Skipped: 50, Pending: 0, Other: 0}, + Tests: tests, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + results := ctrf.Normalize(report, "bench-report-id", "bench-team-id") + _ = results + } +} + +// BenchmarkBuildReportData_100Results measures building quality report data for 100 results. +func BenchmarkBuildReportData_100Results(b *testing.B) { + results := make([]model.TestResult, 100) + for i := range results { + status := "passed" + if i%10 == 0 { + status = "failed" + } + results[i] = model.TestResult{ + Name: "test-name", + Status: status, + DurationMs: int64(i * 50), + } + } + report := &ctrf.Report{ + Results: ctrf.Results{ + Summary: ctrf.Summary{Tests: 100, Passed: 90, Failed: 10}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + data := buildReportData(report, results, nil) + _ = data + } +} + +// BenchmarkBuildReportData_1000Results measures building quality report data for 1000 results. +func BenchmarkBuildReportData_1000Results(b *testing.B) { + results := make([]model.TestResult, 1000) + for i := range results { + status := "passed" + if i%20 == 0 { + status = "failed" + } + results[i] = model.TestResult{ + Name: "test-name", + Status: status, + DurationMs: int64(i * 5), + } + } + report := &ctrf.Report{ + Results: ctrf.Results{ + Summary: ctrf.Summary{Tests: 1000, Passed: 900, Failed: 50, Skipped: 50}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + data := buildReportData(report, results, nil) + _ = data + } +} diff --git a/internal/handler/reports_store_test.go b/internal/handler/reports_store_test.go new file mode 100644 index 00000000..b2dfffd2 --- /dev/null +++ b/internal/handler/reports_store_test.go @@ -0,0 +1,327 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5" + + "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" +) + +type mockReportsStore struct { + listFunc func(ctx context.Context, filter store.ReportListFilter) ([]map[string]interface{}, int, error) + createWithResultsFunc func(ctx context.Context, p store.CreateReportParams, results []model.TestResult) error + getFunc func(ctx context.Context, id, teamID string) (*model.TestReport, error) + deleteFunc func(ctx context.Context, id, teamID string) (int64, error) + executionExistsFunc func(ctx context.Context, executionID, teamID string) (bool, error) + getReportAndResultsFunc func(ctx context.Context, id, teamID string) (*model.TestReport, map[string]*model.TestResult, error) + getPreviousFailedTestsFunc func(ctx context.Context, teamID, currentReportID string) (map[string]bool, error) +} + +func (m *mockReportsStore) List(ctx context.Context, filter store.ReportListFilter) ([]map[string]interface{}, int, error) { + return m.listFunc(ctx, filter) +} +func (m *mockReportsStore) CreateWithResults(ctx context.Context, p store.CreateReportParams, results []model.TestResult) error { + return m.createWithResultsFunc(ctx, p, results) +} +func (m *mockReportsStore) Get(ctx context.Context, id, teamID string) (*model.TestReport, error) { + return m.getFunc(ctx, id, teamID) +} +func (m *mockReportsStore) Delete(ctx context.Context, id, teamID string) (int64, error) { + return m.deleteFunc(ctx, id, teamID) +} +func (m *mockReportsStore) ExecutionExists(ctx context.Context, executionID, teamID string) (bool, error) { + return m.executionExistsFunc(ctx, executionID, teamID) +} +func (m *mockReportsStore) GetReportAndResults(ctx context.Context, id, teamID string) (*model.TestReport, map[string]*model.TestResult, error) { + return m.getReportAndResultsFunc(ctx, id, teamID) +} +func (m *mockReportsStore) GetPreviousFailedTests(ctx context.Context, teamID, currentReportID string) (map[string]bool, error) { + return m.getPreviousFailedTestsFunc(ctx, teamID, currentReportID) +} + +func TestReportsHandler_Get_WithStore(t *testing.T) { + now := time.Now() + ms := &mockReportsStore{ + getFunc: func(_ context.Context, id, teamID string) (*model.TestReport, error) { + if id == "report-1" && teamID == "team-1" { + return &model.TestReport{ + ID: "report-1", + TeamID: "team-1", + ToolName: "jest", + Summary: json.RawMessage(`{"tests":5,"passed":4,"failed":1}`), + CreatedAt: now, + }, nil + } + return nil, pgx.ErrNoRows + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/reports/report-1", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "reportID", "report-1") + + h.Get(w, r) + + if w.Code != 200 { + t.Errorf("Get with store: status = %d, want 200 (body: %s)", w.Code, w.Body.String()) + } +} + +func TestReportsHandler_Get_WithStore_NotFound(t *testing.T) { + ms := &mockReportsStore{ + getFunc: func(_ context.Context, _, _ string) (*model.TestReport, error) { + return nil, pgx.ErrNoRows + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/reports/nonexistent", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "reportID", "nonexistent") + + h.Get(w, r) + + if w.Code != 404 { + t.Errorf("Get not found: status = %d, want 404", w.Code) + } +} + +func TestReportsHandler_Delete_WithStore(t *testing.T) { + ms := &mockReportsStore{ + deleteFunc: func(_ context.Context, id, teamID string) (int64, error) { + return 1, nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/api/v1/reports/report-1", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "reportID", "report-1") + + h.Delete(w, r) + + if w.Code != 200 { + t.Errorf("Delete with store: status = %d, want 200", w.Code) + } +} + +func TestReportsHandler_Delete_WithStore_NotFound(t *testing.T) { + ms := &mockReportsStore{ + deleteFunc: func(_ context.Context, _, _ string) (int64, error) { + return 0, nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("DELETE", "/api/v1/reports/nonexistent", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + r = testWithChiParam(r, "reportID", "nonexistent") + + h.Delete(w, r) + + if w.Code != 404 { + t.Errorf("Delete not found: status = %d, want 404", w.Code) + } +} + +func TestReportsHandler_Create_WithStore_BulkInsert(t *testing.T) { + var capturedResults []model.TestResult + ms := &mockReportsStore{ + executionExistsFunc: func(_ context.Context, _, _ string) (bool, error) { + return false, nil + }, + createWithResultsFunc: func(_ context.Context, p store.CreateReportParams, results []model.TestResult) error { + capturedResults = results + return nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":3,"passed":2,"failed":1,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":100},{"name":"t2","status":"passed","duration":200},{"name":"t3","status":"failed","duration":300,"message":"oops"}]}}` + r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.Create(w, r) + + if w.Code != 201 { + t.Errorf("Create with store: status = %d, want 201 (body: %s)", w.Code, w.Body.String()) + } + if len(capturedResults) != 3 { + t.Errorf("Create with store: expected 3 results, got %d", len(capturedResults)) + } +} + +func TestReportsHandler_List_WithStore(t *testing.T) { + ms := &mockReportsStore{ + listFunc: func(_ context.Context, filter store.ReportListFilter) ([]map[string]interface{}, int, error) { + reports := []map[string]interface{}{ + {"id": "r1", "tool_name": "jest", "total": 10}, + {"id": "r2", "tool_name": "mocha", "total": 20}, + } + return reports, 2, nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/reports", nil) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.List(w, r) + + if w.Code != 200 { + t.Errorf("List with store: status = %d, want 200", w.Code) + } +} + +func TestReportsHandler_Create_WithStore_ExecutionExists(t *testing.T) { + ms := &mockReportsStore{ + executionExistsFunc: func(_ context.Context, executionID, teamID string) (bool, error) { + return executionID == "550e8400-e29b-41d4-a716-446655440000" && teamID == "team-1", nil + }, + createWithResultsFunc: func(_ context.Context, p store.CreateReportParams, results []model.TestResult) error { + return nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` + r := httptest.NewRequest("POST", "/api/v1/reports?execution_id=550e8400-e29b-41d4-a716-446655440000", strings.NewReader(report)) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.Create(w, r) + + if w.Code != 201 { + t.Errorf("Create with exec_id: status = %d, want 201 (body: %s)", w.Code, w.Body.String()) + } +} + +func TestReportsHandler_Create_WithStore_ExecutionNotFound(t *testing.T) { + ms := &mockReportsStore{ + executionExistsFunc: func(_ context.Context, _, _ string) (bool, error) { + return false, nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` + r := httptest.NewRequest("POST", "/api/v1/reports?execution_id=exec-nonexistent", strings.NewReader(report)) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.Create(w, r) + + if w.Code != 400 { + t.Errorf("Create with bad exec_id: status = %d, want 400", w.Code) + } +} + +type mockAdminStore struct { + listUsersFunc func(ctx context.Context, limit, offset int) ([]model.User, int, error) +} + +func (m *mockAdminStore) ListUsers(ctx context.Context, limit, offset int) ([]model.User, int, error) { + return m.listUsersFunc(ctx, limit, offset) +} + +func TestReportsHandler_Create_WithStore_BulkInsert_LargeBatch(t *testing.T) { + var capturedResults []model.TestResult + ms := &mockReportsStore{ + executionExistsFunc: func(_ context.Context, _, _ string) (bool, error) { + return false, nil + }, + createWithResultsFunc: func(_ context.Context, p store.CreateReportParams, results []model.TestResult) error { + capturedResults = results + return nil + }, + } + h := &ReportsHandler{ReportStore: ms} + w := httptest.NewRecorder() + + var testsJSON strings.Builder + testsJSON.WriteString(`{"results":{"tool":{"name":"jest"},"summary":{"tests":200,"passed":180,"failed":15,"skipped":5,"pending":0,"other":0},"tests":[`) + for i := 0; i < 200; i++ { + if i > 0 { + testsJSON.WriteByte(',') + } + status := "passed" + if i%13 == 0 { + status = "failed" + } + if i%40 == 0 { + status = "skipped" + } + fmt.Fprintf(&testsJSON, `{"name":"bulk-test-%d","status":"%s","duration":%d}`, i, status, i*10) + } + testsJSON.WriteString(`]}}`) + + r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(testsJSON.String())) + r = testWithClaimsSimple(r, "user-1", "team-1", "owner") + + h.Create(w, r) + + if w.Code != 201 { + t.Errorf("Create with store bulk: status = %d, want 201 (body: %s)", w.Code, w.Body.String()) + } + if len(capturedResults) != 200 { + t.Errorf("Create with store bulk: expected 200 results passed in single CreateWithResults call, got %d", len(capturedResults)) + } +} + +func TestAdminHandler_ListUsers_WithStore(t *testing.T) { + ms := &mockAdminStore{ + listUsersFunc: func(_ context.Context, limit, offset int) ([]model.User, int, error) { + return []model.User{ + {ID: "u1", Email: "admin@test.com", DisplayName: "Admin", Role: "owner"}, + }, 1, nil + }, + } + h := &AdminHandler{AdminStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/admin/users", nil) + r = testWithClaims(r, testClaims) + + h.ListUsers(w, r) + + if w.Code != http.StatusOK { + t.Errorf("ListUsers with store: status = %d, want %d", w.Code, http.StatusOK) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("unmarshal ListUsers: %v", err) + } + users, ok := resp["users"].([]interface{}) + if !ok { + t.Fatal("expected users array in response") + } + if len(users) != 1 { + t.Errorf("expected 1 user, got %d", len(users)) + } +} + +func TestAdminHandler_ListUsers_WithStore_Empty(t *testing.T) { + ms := &mockAdminStore{ + listUsersFunc: func(_ context.Context, _, _ int) ([]model.User, int, error) { + return nil, 0, nil + }, + } + h := &AdminHandler{AdminStore: ms} + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/api/v1/admin/users", nil) + r = testWithClaims(r, testClaims) + + h.ListUsers(w, r) + + if w.Code != http.StatusOK { + t.Errorf("ListUsers empty: status = %d, want %d", w.Code, http.StatusOK) + } +} diff --git a/internal/handler/reports_test.go b/internal/handler/reports_test.go index dc7c781c..3ffeb33e 100644 --- a/internal/handler/reports_test.go +++ b/internal/handler/reports_test.go @@ -15,9 +15,23 @@ import ( "github.com/scaledtest/scaledtest/internal/ctrf" "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" "github.com/scaledtest/scaledtest/internal/webhook" ) +func newCreateMockStore() *mockReportsStore { + return &mockReportsStore{ + createWithResultsFunc: func(_ context.Context, _ store.CreateReportParams, _ []model.TestResult) error { + return nil + }, + executionExistsFunc: func(_ context.Context, _, _ string) (bool, error) { + return true, nil + }, + getPreviousFailedTestsFunc: func(_ context.Context, _, _ string) (map[string]bool, error) { + return nil, nil + }, + } +} func TestListReports_Unauthorized(t *testing.T) { h := &ReportsHandler{} @@ -32,7 +46,7 @@ func TestListReports_Unauthorized(t *testing.T) { } func TestListReports_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -45,7 +59,7 @@ func TestListReports_NoDB(t *testing.T) { } func TestListReports_ValidSinceUntil_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports?since=2024-01-01T00:00:00Z&until=2024-12-31T23:59:59Z", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -137,7 +151,7 @@ func TestCreateReport_InvalidCTRF(t *testing.T) { } func TestCreateReport_NoDB_Fallback(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":2,"passed":1,"failed":1,"skipped":0,"pending":0,"other":0},"tests":[{"name":"test1","status":"passed","duration":100},{"name":"test2","status":"failed","duration":200,"message":"oops"}]}}` r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -145,24 +159,13 @@ func TestCreateReport_NoDB_Fallback(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create without DB (fallback): got %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) - } - - var resp map[string]interface{} - if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if resp["tool"] != "jest" { - t.Errorf("tool = %v, want jest", resp["tool"]) - } - if resp["tests"] != float64(2) { - t.Errorf("tests = %v, want 2", resp["tests"]) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create without ReportStore: got %d, want %d (body: %s)", w.Code, http.StatusServiceUnavailable, w.Body.String()) } } func TestCreateReport_NoDB_WithExecutionID(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"mocha"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":50}]}}` r := httptest.NewRequest("POST", "/api/v1/reports?execution_id=exec-123", strings.NewReader(report)) @@ -170,19 +173,13 @@ func TestCreateReport_NoDB_WithExecutionID(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create with execution_id: got %d, want %d", w.Code, http.StatusCreated) - } - - var resp map[string]interface{} - json.NewDecoder(w.Body).Decode(&resp) - if resp["execution_id"] != "exec-123" { - t.Errorf("execution_id = %v, want exec-123", resp["execution_id"]) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create with execution_id but no ReportStore: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } func TestCreateReport_NoDB_WithTriageGitHubStatus(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":0,"failed":1,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"failed","duration":50}]}}` r := httptest.NewRequest("POST", "/api/v1/reports?triage_github_status=true", strings.NewReader(report)) @@ -190,14 +187,8 @@ func TestCreateReport_NoDB_WithTriageGitHubStatus(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create with triage_github_status=true: got %d, want %d", w.Code, http.StatusCreated) - } - - var resp map[string]interface{} - json.NewDecoder(w.Body).Decode(&resp) - if resp["triage_github_status"] != true { - t.Errorf("triage_github_status = %v, want true", resp["triage_github_status"]) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create with triage_github_status=true but no ReportStore: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -229,7 +220,7 @@ func TestGetReport_MissingID(t *testing.T) { } func TestGetReport_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports/abc", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -270,7 +261,7 @@ func TestDeleteReport_MissingID(t *testing.T) { } func TestDeleteReport_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("DELETE", "/api/v1/reports/abc", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -291,12 +282,12 @@ func TestParsePagination(t *testing.T) { }{ {"", 50, 0}, {"?limit=10&offset=20", 10, 20}, - {"?limit=200", 50, 0}, // exceeds max - {"?limit=-1", 50, 0}, // negative - {"?limit=abc", 50, 0}, // non-numeric - {"?offset=-5", 50, 0}, // negative offset - {"?limit=100", 100, 0}, // max allowed - {"?limit=0", 50, 0}, // zero not allowed + {"?limit=200", 50, 0}, // exceeds max + {"?limit=-1", 50, 0}, // negative + {"?limit=abc", 50, 0}, // non-numeric + {"?offset=-5", 50, 0}, // negative offset + {"?limit=100", 100, 0}, // max allowed + {"?limit=0", 50, 0}, // zero not allowed } for _, tt := range tests { @@ -310,11 +301,11 @@ func TestParsePagination(t *testing.T) { } func TestNullString(t *testing.T) { - if got := nullString(""); got != nil { - t.Errorf("nullString(\"\") = %v, want nil", got) + if got := store.NullString(""); got != nil { + t.Errorf("NullString(\"\") = %v, want nil", got) } - if got := nullString("hello"); got == nil || *got != "hello" { - t.Errorf("nullString(\"hello\") = %v, want &\"hello\"", got) + if got := store.NullString("hello"); got == nil || *got != "hello" { + t.Errorf("NullString(\"hello\") = %v, want &\"hello\"", got) } } @@ -371,7 +362,7 @@ func TestCompareReports_SameID(t *testing.T) { } func TestCompareReports_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports/compare?base=a&head=b", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -446,7 +437,7 @@ func TestCreateReport_TestMissingName(t *testing.T) { } func TestCreateReport_NoDB_AllTestStatuses(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"pytest"},"summary":{"tests":5,"passed":1,"failed":1,"skipped":1,"pending":1,"other":1},"tests":[ {"name":"t1","status":"passed","duration":10}, @@ -460,25 +451,13 @@ func TestCreateReport_NoDB_AllTestStatuses(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create with all statuses: got %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) - } - - var resp map[string]interface{} - json.NewDecoder(w.Body).Decode(&resp) - if resp["tool"] != "pytest" { - t.Errorf("tool = %v, want pytest", resp["tool"]) - } - if resp["tests"] != float64(5) { - t.Errorf("tests = %v, want 5", resp["tests"]) - } - if resp["message"] != "report accepted" { - t.Errorf("message = %v, want 'report accepted'", resp["message"]) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create with all statuses but no ReportStore: got %d, want %d (body: %s)", w.Code, http.StatusServiceUnavailable, w.Body.String()) } } func TestCreateReport_NoDB_RichCTRFData(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{ "tool":{"name":"playwright","version":"1.40.0"}, @@ -494,22 +473,13 @@ func TestCreateReport_NoDB_RichCTRFData(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create with rich CTRF: got %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) - } - - var resp map[string]interface{} - json.NewDecoder(w.Body).Decode(&resp) - if resp["tool"] != "playwright" { - t.Errorf("tool = %v, want playwright", resp["tool"]) - } - if resp["tests"] != float64(2) { - t.Errorf("tests = %v, want 2", resp["tests"]) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create with rich CTRF but no ReportStore: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } func TestCreateReport_NoDB_NoExecutionIDInResponse(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -517,14 +487,8 @@ func TestCreateReport_NoDB_NoExecutionIDInResponse(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("got %d, want %d", w.Code, http.StatusCreated) - } - - var resp map[string]interface{} - json.NewDecoder(w.Body).Decode(&resp) - if _, ok := resp["execution_id"]; ok { - t.Error("execution_id should not be present when not provided") + if w.Code != http.StatusServiceUnavailable { + t.Errorf("got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -552,7 +516,7 @@ func TestListReports_DateFilterParams_NoDB(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports"+tt.query, nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -567,7 +531,7 @@ func TestListReports_DateFilterParams_NoDB(t *testing.T) { } func TestListReports_DateFilterWithPagination_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports?since=2026-01-01T00:00:00Z&until=2026-06-01T00:00:00Z&limit=25&offset=10", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -584,7 +548,7 @@ func TestListReports_DateFilterWithPagination_NoDB(t *testing.T) { func TestListReports_RequiresClaims(t *testing.T) { // Without any claims, List must return 401 regardless of query params - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports?limit=10", nil) @@ -596,7 +560,6 @@ func TestListReports_RequiresClaims(t *testing.T) { } func TestCreateReport_DifferentTeams_NoDB(t *testing.T) { - // Verify that reports created by different teams work in no-DB fallback mode teams := []struct { teamID string tool string @@ -608,7 +571,7 @@ func TestCreateReport_DifferentTeams_NoDB(t *testing.T) { for _, tt := range teams { t.Run(tt.teamID, func(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"` + tt.tool + `"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -616,8 +579,8 @@ func TestCreateReport_DifferentTeams_NoDB(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create for %s: got %d, want %d", tt.teamID, w.Code, http.StatusCreated) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create for %s: got %d, want %d", tt.teamID, w.Code, http.StatusServiceUnavailable) } }) } @@ -729,7 +692,7 @@ func TestParsePagination_AdditionalCases(t *testing.T) { wantLimit int wantOffset int }{ - {"?limit=1", 1, 0}, // minimum valid limit + {"?limit=1", 1, 0}, // minimum valid limit {"?limit=99&offset=0", 99, 0}, {"?limit=50&offset=1000", 50, 1000}, // large offset {"?limit=101", 50, 0}, // just over max @@ -958,24 +921,13 @@ func TestBuildReportData_WithPreviousFailedTests(t *testing.T) { } } -func TestFetchPreviousFailedTests_NilDB(t *testing.T) { - // When DB is nil, fetchPreviousFailedTests must return (nil, nil) gracefully. - result, err := fetchPreviousFailedTests(context.Background(), nil, "team-1", "report-1") - if err != nil { - t.Errorf("expected nil error when DB is nil, got %v", err) - } - if result != nil { - t.Errorf("expected nil map when DB is nil, got %v", result) - } -} - // --- Mock webhook lister for testing dispatch --- type mockWebhookLister struct { - mu sync.Mutex - calls []mockWebhookCall - hooks []webhook.WebhookRecord - err error + mu sync.Mutex + calls []mockWebhookCall + hooks []webhook.WebhookRecord + err error } type mockWebhookCall struct { @@ -1000,13 +952,12 @@ func (m *mockWebhookLister) getCalls() []mockWebhookCall { // --- Webhook dispatch tests --- -func TestCreateReport_NoDB_WebhookNotSkippedInFallback(t *testing.T) { - // In no-DB fallback mode, webhook dispatch is skipped (return before dispatch code). - // Verify the handler doesn't call the lister in this path. +func TestCreateReport_NoDB_Returns503(t *testing.T) { + // Without ReportStore, Create should return 503 lister := &mockWebhookLister{} notifier := webhook.NewNotifier(lister, webhook.NewDispatcher()) - h := &ReportsHandler{DB: nil, Webhooks: notifier} + h := &ReportsHandler{Webhooks: notifier} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -1014,21 +965,21 @@ func TestCreateReport_NoDB_WebhookNotSkippedInFallback(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Fatalf("Create: got %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("Create: got %d, want %d (body: %s)", w.Code, http.StatusServiceUnavailable, w.Body.String()) } - // In no-DB fallback, webhook dispatch should NOT be called + // No webhook dispatch should happen time.Sleep(100 * time.Millisecond) calls := lister.getCalls() if len(calls) > 0 { - t.Errorf("expected no webhook calls in no-DB fallback, got %d", len(calls)) + t.Errorf("expected no webhook calls without ReportStore, got %d", len(calls)) } } func TestCreateReport_NoDB_NilWebhookNotifierSafe(t *testing.T) { - // Webhooks is nil — should not panic - h := &ReportsHandler{DB: nil, Webhooks: nil} + // Without ReportStore, Create should return 503 + h := &ReportsHandler{Webhooks: nil} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -1036,15 +987,15 @@ func TestCreateReport_NoDB_NilWebhookNotifierSafe(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Errorf("Create with nil Webhooks: got %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Create with nil Webhooks and no ReportStore: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } // --- Quality gate auto-evaluation tests --- func TestCreateReport_NoDB_QualityGateNotEvaluatedWhenStoreNil(t *testing.T) { - h := &ReportsHandler{DB: nil, QualityGateStore: nil} + h := &ReportsHandler{QualityGateStore: nil} w := httptest.NewRecorder() report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -1052,14 +1003,8 @@ func TestCreateReport_NoDB_QualityGateNotEvaluatedWhenStoreNil(t *testing.T) { h.Create(w, r) - if w.Code != http.StatusCreated { - t.Fatalf("Create: got %d, want %d", w.Code, http.StatusCreated) - } - - var resp map[string]interface{} - json.NewDecoder(w.Body).Decode(&resp) - if _, ok := resp["qualityGate"]; ok { - t.Error("qualityGate should not be present when QualityGateStore is nil") + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("Create: got %d, want %d", w.Code, http.StatusServiceUnavailable) } } @@ -1067,10 +1012,10 @@ func TestCreateReport_NoDB_QualityGateNotEvaluatedWhenStoreNil(t *testing.T) { // mockQualityGateStore implements qualityGateEvaluator for testing. type mockQualityGateStore struct { - gates []model.QualityGate - listErr error - evalCalls []mockEvalCall - createErr error + gates []model.QualityGate + listErr error + evalCalls []mockEvalCall + createErr error } type mockEvalCall struct { @@ -1123,7 +1068,11 @@ func TestEvaluateQualityGates_TeamScopedAPIToken(t *testing.T) { }, } - h := &ReportsHandler{QualityGateStore: mockStore} + h := &ReportsHandler{QualityGateStore: mockStore, ReportStore: &mockReportsStore{ + getPreviousFailedTestsFunc: func(_ context.Context, _, _ string) (map[string]bool, error) { + return nil, nil + }, + }} report := &ctrf.Report{ Results: ctrf.Results{ @@ -1199,7 +1148,11 @@ func TestEvaluateQualityGates_PassingGateViaAPIToken(t *testing.T) { }, } - h := &ReportsHandler{QualityGateStore: mockStore} + h := &ReportsHandler{QualityGateStore: mockStore, ReportStore: &mockReportsStore{ + getPreviousFailedTestsFunc: func(_ context.Context, _, _ string) (map[string]bool, error) { + return nil, nil + }, + }} report := &ctrf.Report{ Results: ctrf.Results{ @@ -1236,7 +1189,11 @@ func TestEvaluateQualityGates_NoGatesForTeam(t *testing.T) { }, } - h := &ReportsHandler{QualityGateStore: mockStore} + h := &ReportsHandler{QualityGateStore: mockStore, ReportStore: &mockReportsStore{ + getPreviousFailedTestsFunc: func(_ context.Context, _, _ string) (map[string]bool, error) { + return nil, nil + }, + }} report := &ctrf.Report{ Results: ctrf.Results{ Tool: ctrf.Tool{Name: "test"}, @@ -1265,7 +1222,11 @@ func TestEvaluateQualityGates_MultipleGates(t *testing.T) { }, } - h := &ReportsHandler{QualityGateStore: mockStore} + h := &ReportsHandler{QualityGateStore: mockStore, ReportStore: &mockReportsStore{ + getPreviousFailedTestsFunc: func(_ context.Context, _, _ string) (map[string]bool, error) { + return nil, nil + }, + }} report := &ctrf.Report{ Results: ctrf.Results{ Tool: ctrf.Tool{Name: "test"}, @@ -1307,7 +1268,7 @@ func TestEvaluateQualityGates_MultipleGates(t *testing.T) { // --- Nil DB returns 503 tests --- func TestCreateReport_NilDB_Returns503ForGet(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports/some-id", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1321,7 +1282,7 @@ func TestCreateReport_NilDB_Returns503ForGet(t *testing.T) { } func TestDeleteReport_NilDB_Returns503(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("DELETE", "/api/v1/reports/some-id", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1335,7 +1296,7 @@ func TestDeleteReport_NilDB_Returns503(t *testing.T) { } func TestCompareReports_NilDB_Returns503(t *testing.T) { - h := &ReportsHandler{DB: nil} + h := &ReportsHandler{} w := httptest.NewRecorder() r := httptest.NewRequest("GET", "/api/v1/reports/compare?base=a&head=b", nil) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1399,7 +1360,7 @@ const failingReport = `{"results":{"tool":{"name":"playwright"},"summary":{"test func TestCreateReport_PostsGitHubStatus_AllPassed(t *testing.T) { poster := &mockGitHubStatusPoster{} - h := &ReportsHandler{DB: nil, GitHubStatusPoster: poster} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: poster} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234", strings.NewReader(validReport)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1431,7 +1392,7 @@ func TestCreateReport_PostsGitHubStatus_AllPassed(t *testing.T) { func TestCreateReport_PostsGitHubStatus_WithFailures(t *testing.T) { poster := &mockGitHubStatusPoster{} - h := &ReportsHandler{DB: nil, GitHubStatusPoster: poster} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: poster} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234", strings.NewReader(failingReport)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1451,7 +1412,7 @@ func TestCreateReport_PostsGitHubStatus_WithFailures(t *testing.T) { func TestCreateReport_GitHubStatusError_IsNonFatal(t *testing.T) { poster := &mockGitHubStatusPoster{err: fmt.Errorf("github API down")} - h := &ReportsHandler{DB: nil, GitHubStatusPoster: poster} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: poster} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234", strings.NewReader(validReport)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1470,7 +1431,7 @@ func TestCreateReport_GitHubStatusError_IsNonFatal(t *testing.T) { func TestCreateReport_NoGitHubParams_NoPosterCalled(t *testing.T) { poster := &mockGitHubStatusPoster{} - h := &ReportsHandler{DB: nil, GitHubStatusPoster: poster} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: poster} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(validReport)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1487,7 +1448,7 @@ func TestCreateReport_NoGitHubParams_NoPosterCalled(t *testing.T) { } func TestCreateReport_NilGitHubPoster_NoError(t *testing.T) { - h := &ReportsHandler{DB: nil, GitHubStatusPoster: nil} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: nil} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234", strings.NewReader(validReport)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1501,10 +1462,10 @@ func TestCreateReport_NilGitHubPoster_NoError(t *testing.T) { func TestCreateReport_GitHubStatus_WithExecutionID_LinksToExecution(t *testing.T) { poster := &mockGitHubStatusPoster{} - h := &ReportsHandler{DB: nil, GitHubStatusPoster: poster, BaseURL: "http://example.com"} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: poster, BaseURL: "http://example.com"} w := httptest.NewRecorder() r := httptest.NewRequest("POST", - "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234&execution_id=exec-uuid-123", + "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234&execution_id=00000000-0000-0000-0000-000000000123", strings.NewReader(validReport)) r = testWithClaimsSimple(r, "user-1", "team-1", "owner") @@ -1516,18 +1477,18 @@ func TestCreateReport_GitHubStatus_WithExecutionID_LinksToExecution(t *testing.T eventually(t, 500, func() bool { return poster.callCount() == 1 }, "GitHub status not posted") call, _ := poster.firstCall() - wantURL := "http://example.com/executions/exec-uuid-123" + wantURL := "http://example.com/executions/00000000-0000-0000-0000-000000000123" if call.TargetURL != wantURL { t.Errorf("targetURL = %q, want %q", call.TargetURL, wantURL) } - if !strings.Contains(call.Description, "exec-uuid-123") { + if !strings.Contains(call.Description, "00000000-0000-0000-0000-000000000123") { t.Errorf("description %q should contain execution ID", call.Description) } } func TestCreateReport_GitHubStatus_WithoutExecutionID_LinksToReport(t *testing.T) { poster := &mockGitHubStatusPoster{} - h := &ReportsHandler{DB: nil, GitHubStatusPoster: poster, BaseURL: "http://example.com"} + h := &ReportsHandler{ReportStore: newCreateMockStore(), GitHubStatusPoster: poster, BaseURL: "http://example.com"} w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports?github_owner=acme&github_repo=app&github_sha=abc1234", @@ -1542,10 +1503,12 @@ func TestCreateReport_GitHubStatus_WithoutExecutionID_LinksToReport(t *testing.T eventually(t, 500, func() bool { return poster.callCount() == 1 }, "GitHub status not posted") call, _ := poster.firstCall() - // Without execution_id and no reportID (no-DB path), targetURL should be empty if strings.Contains(call.TargetURL, "/executions/") { t.Errorf("targetURL %q should not link to execution when no execution_id", call.TargetURL) } + if !strings.Contains(call.TargetURL, "/reports/") { + t.Errorf("targetURL %q should link to report", call.TargetURL) + } } // --------------------------------------------------------------------------- @@ -1590,9 +1553,9 @@ func TestFlattenReportForList_PromotesSummaryFields(t *testing.T) { func TestFlattenReportForList_ZeroCounts(t *testing.T) { rpt := model.TestReport{ - ID: "report-2", - TeamID: "team-1", - Summary: json.RawMessage(`{"tests":0,"passed":0,"failed":0,"skipped":0,"pending":0,"other":0}`), + ID: "report-2", + TeamID: "team-1", + Summary: json.RawMessage(`{"tests":0,"passed":0,"failed":0,"skipped":0,"pending":0,"other":0}`), } out := flattenReportForList(rpt) @@ -2337,7 +2300,7 @@ func TestRetryTriage_ForceResetNoOp_Returns202WithoutEnqueue(t *testing.T) { // TestCreateReport_TriageEnqueuer_NilEnqueuer_NoDB verifies that the handler // does not panic when TriageEnqueuer is nil (triage disabled). func TestCreateReport_TriageEnqueuer_NilEnqueuer_NoDB(t *testing.T) { - h := &ReportsHandler{DB: nil, TriageEnqueuer: nil} + h := &ReportsHandler{ReportStore: newCreateMockStore(), TriageEnqueuer: nil} report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -2351,11 +2314,11 @@ func TestCreateReport_TriageEnqueuer_NilEnqueuer_NoDB(t *testing.T) { } } -// TestCreateReport_TriageEnqueuer_NotCalledOnNoDB verifies that Enqueue is -// not invoked when there is no database (no persistent report to triage). -func TestCreateReport_TriageEnqueuer_NotCalledOnNoDB(t *testing.T) { +// TestCreateReport_TriageEnqueuer_CalledWhenStorePresent verifies that Enqueue +// is invoked when the report is persisted via ReportStore. +func TestCreateReport_TriageEnqueuer_CalledWhenStorePresent(t *testing.T) { enqueuer := &capTriageEnqueuer{} - h := &ReportsHandler{DB: nil, TriageEnqueuer: enqueuer} + h := &ReportsHandler{ReportStore: newCreateMockStore(), TriageEnqueuer: enqueuer} report := `{"results":{"tool":{"name":"jest"},"summary":{"tests":1,"passed":1,"failed":0,"skipped":0,"pending":0,"other":0},"tests":[{"name":"t1","status":"passed","duration":10}]}}` w := httptest.NewRecorder() r := httptest.NewRequest("POST", "/api/v1/reports", strings.NewReader(report)) @@ -2366,7 +2329,7 @@ func TestCreateReport_TriageEnqueuer_NotCalledOnNoDB(t *testing.T) { if w.Code != http.StatusCreated { t.Fatalf("unexpected status: %d", w.Code) } - if enqueuer.count() != 0 { - t.Errorf("Enqueue should not be called in no-DB mode; got %d calls", enqueuer.count()) + if enqueuer.count() != 1 { + t.Errorf("Enqueue should be called once with ReportStore; got %d calls", enqueuer.count()) } } diff --git a/internal/handler/store_interfaces.go b/internal/handler/store_interfaces.go new file mode 100644 index 00000000..352cdd0f --- /dev/null +++ b/internal/handler/store_interfaces.go @@ -0,0 +1,75 @@ +package handler + +import ( + "context" + "net" + "time" + + "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" +) + +// authStore abstracts auth persistence operations. +type authStore interface { + GetUserByEmail(ctx context.Context, email string) (*model.User, error) + GetUserByID(ctx context.Context, id string) (*model.User, error) + EmailExists(ctx context.Context, email string) (bool, error) + CreateUser(ctx context.Context, email, passwordHash, displayName, role string) (userID, returnedRole string, err error) + CreateUserWithRole(ctx context.Context, email, passwordHash, displayName, role string) (userID string, err error) + UpdatePassword(ctx context.Context, userID, passwordHash string) (int64, error) + UpdateProfile(ctx context.Context, userID, displayName string) (*model.User, error) + GetPrimaryTeamID(ctx context.Context, userID string) (string, error) + CreateSession(ctx context.Context, userID, refreshToken, userAgent string, ipAddr net.IP, expiresAt time.Time) error + GetSessionByRefreshToken(ctx context.Context, refreshToken string) (*store.SessionInfo, error) + DeleteSession(ctx context.Context, sessionID string) error + DeleteSessionByRefreshToken(ctx context.Context, refreshToken string) error +} + +// analyticsStore abstracts analytics query operations. +type analyticsStore interface { + QueryTrends(ctx context.Context, groupBy, teamID string, start, end time.Time) ([]store.TrendRow, error) + QueryDurationBuckets(ctx context.Context, teamID string, start, end time.Time) ([]int64, error) + QueryErrorClusters(ctx context.Context, teamID string, start, end time.Time, limit int) ([]store.ErrorClusterRow, error) + QueryFlakyTests(ctx context.Context, teamID string, cutoff time.Time, minRuns int) ([]store.FlakyRow, error) +} + +// executionsStore abstracts execution persistence operations. +type executionsStore interface { + List(ctx context.Context, teamID string, limit, offset int) ([]model.TestExecution, int, error) + Create(ctx context.Context, teamID, command string, configJSON []byte) (string, error) + Get(ctx context.Context, id, teamID string) (*model.TestExecution, error) + Cancel(ctx context.Context, id, teamID string, now time.Time) (int64, error) + UpdateStatus(ctx context.Context, id, teamID, status string, now time.Time, errorMsg *string) (int64, error) + Exists(ctx context.Context, id, teamID string) (bool, error) + GetK8sJobName(ctx context.Context, id string) (*string, error) + SetK8sJobName(ctx context.Context, id, jobName string, now time.Time) error + MarkFailed(ctx context.Context, id, errorMsg string, now time.Time) error +} + +// reportsStore abstracts report persistence operations. +type reportsStore interface { + List(ctx context.Context, filter store.ReportListFilter) ([]map[string]interface{}, int, error) + CreateWithResults(ctx context.Context, p store.CreateReportParams, results []model.TestResult) error + Get(ctx context.Context, id, teamID string) (*model.TestReport, error) + Delete(ctx context.Context, id, teamID string) (int64, error) + ExecutionExists(ctx context.Context, executionID, teamID string) (bool, error) + GetReportAndResults(ctx context.Context, id, teamID string) (*model.TestReport, map[string]*model.TestResult, error) + GetPreviousFailedTests(ctx context.Context, teamID, currentReportID string) (map[string]bool, error) +} + +// teamsStore abstracts team and token data operations for testable handlers. +type teamsStore interface { + ListTeams(ctx context.Context, userID string) ([]store.TeamWithRole, error) + GetTeam(ctx context.Context, teamID, userID string) (*store.TeamWithRole, error) + GetUserRole(ctx context.Context, userID, teamID string) (string, error) + CreateTeam(ctx context.Context, userID, name string) (*model.Team, error) + DeleteTeam(ctx context.Context, teamID string) error + ListTokens(ctx context.Context, teamID string) ([]model.APIToken, error) + CreateToken(ctx context.Context, teamID, userID, name, tokenHash, prefix string) (*model.APIToken, error) + DeleteToken(ctx context.Context, teamID, tokenID string) (int64, error) +} + +// adminStore abstracts admin query operations. +type adminStore interface { + ListUsers(ctx context.Context, limit, offset int) ([]model.User, int, error) +} diff --git a/internal/handler/teams.go b/internal/handler/teams.go index 642c8888..f19d80ea 100644 --- a/internal/handler/teams.go +++ b/internal/handler/teams.go @@ -1,31 +1,19 @@ package handler import ( - "context" "net/http" "github.com/go-chi/chi/v5" "github.com/jackc/pgx/v5" "github.com/scaledtest/scaledtest/internal/auth" - "github.com/scaledtest/scaledtest/internal/db" "github.com/scaledtest/scaledtest/internal/model" "github.com/scaledtest/scaledtest/internal/sanitize" "github.com/scaledtest/scaledtest/internal/store" ) -// teamsStore abstracts team and token data operations for testable handlers. -type teamsStore interface { - CreateTeam(ctx context.Context, userID, name string) (*model.Team, error) - GetUserRole(ctx context.Context, userID, teamID string) (string, error) - DeleteTeam(ctx context.Context, teamID string) error - CreateToken(ctx context.Context, teamID, userID, name, tokenHash, prefix string) (*model.APIToken, error) - DeleteToken(ctx context.Context, teamID, tokenID string) (int64, error) -} - // TeamsHandler handles team management endpoints. type TeamsHandler struct { - DB *db.Pool Store teamsStore AuditStore auditLogger } @@ -48,39 +36,18 @@ func (h *TeamsHandler) List(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.Store == nil { JSON(w, http.StatusOK, map[string]interface{}{"teams": []interface{}{}}) return } - rows, err := h.DB.Query(r.Context(), - `SELECT t.id, t.name, t.created_at, ut.role - FROM teams t - JOIN user_teams ut ON ut.team_id = t.id - WHERE ut.user_id = $1 - ORDER BY t.name`, claims.UserID) + teams, err := h.Store.ListTeams(r.Context(), claims.UserID) if err != nil { Error(w, http.StatusInternalServerError, "failed to list teams") return } - defer rows.Close() - - type teamWithRole struct { - model.Team - Role string `json:"role"` - } - - var teams []teamWithRole - for rows.Next() { - var t teamWithRole - if err := rows.Scan(&t.ID, &t.Name, &t.CreatedAt, &t.Role); err != nil { - Error(w, http.StatusInternalServerError, "failed to scan team") - return - } - teams = append(teams, t) - } if teams == nil { - teams = []teamWithRole{} + teams = []store.TeamWithRole{} } JSON(w, http.StatusOK, map[string]interface{}{"teams": teams}) @@ -140,20 +107,12 @@ func (h *TeamsHandler) Get(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.Store == nil { Error(w, http.StatusNotImplemented, "get team requires database connection") return } - // Verify membership and get team - var team model.Team - var role string - err := h.DB.QueryRow(r.Context(), - `SELECT t.id, t.name, t.created_at, ut.role - FROM teams t - JOIN user_teams ut ON ut.team_id = t.id - WHERE t.id = $1 AND ut.user_id = $2`, teamID, claims.UserID). - Scan(&team.ID, &team.Name, &team.CreatedAt, &role) + result, err := h.Store.GetTeam(r.Context(), teamID, claims.UserID) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "team not found") return @@ -164,8 +123,8 @@ func (h *TeamsHandler) Get(w http.ResponseWriter, r *http.Request) { } JSON(w, http.StatusOK, map[string]interface{}{ - "team": team, - "role": role, + "team": result.Team, + "role": result.Role, }) } @@ -234,13 +193,13 @@ func (h *TeamsHandler) ListTokens(w http.ResponseWriter, r *http.Request) { return } - if h.DB == nil { + if h.Store == nil { JSON(w, http.StatusOK, map[string]interface{}{"tokens": []interface{}{}}) return } // Verify team membership - _, err := h.getUserTeamRole(r.Context(), claims.UserID, teamID) + _, err := h.Store.GetUserRole(r.Context(), claims.UserID, teamID) if err == pgx.ErrNoRows { Error(w, http.StatusNotFound, "team not found") return @@ -250,26 +209,11 @@ func (h *TeamsHandler) ListTokens(w http.ResponseWriter, r *http.Request) { return } - rows, err := h.DB.Query(r.Context(), - `SELECT id, team_id, user_id, name, prefix, last_used_at, created_at - FROM api_tokens - WHERE team_id = $1 - ORDER BY created_at DESC`, teamID) + tokens, err := h.Store.ListTokens(r.Context(), teamID) if err != nil { Error(w, http.StatusInternalServerError, "failed to list tokens") return } - defer rows.Close() - - var tokens []model.APIToken - for rows.Next() { - var t model.APIToken - if err := rows.Scan(&t.ID, &t.TeamID, &t.UserID, &t.Name, &t.Prefix, &t.LastUsedAt, &t.CreatedAt); err != nil { - Error(w, http.StatusInternalServerError, "failed to scan token") - return - } - tokens = append(tokens, t) - } if tokens == nil { tokens = []model.APIToken{} } @@ -414,12 +358,3 @@ func (h *TeamsHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { JSON(w, http.StatusOK, map[string]string{"message": "token revoked"}) } - -// getUserTeamRole returns the user's role in the team, or an error if not a member. -func (h *TeamsHandler) getUserTeamRole(ctx context.Context, userID, teamID string) (string, error) { - var role string - err := h.DB.QueryRow(ctx, - `SELECT role FROM user_teams WHERE user_id = $1 AND team_id = $2`, - userID, teamID).Scan(&role) - return role, err -} diff --git a/internal/handler/teams_test.go b/internal/handler/teams_test.go index 0a902eee..a8ce2ce4 100644 --- a/internal/handler/teams_test.go +++ b/internal/handler/teams_test.go @@ -10,9 +10,11 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/jackc/pgx/v5" "github.com/scaledtest/scaledtest/internal/auth" "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" ) var testClaims = &auth.Claims{ @@ -239,16 +241,28 @@ func TestDeleteToken_NoDB(t *testing.T) { } } -// mockTeamsStore implements teamsStore for audit logging tests. +// mockTeamsStore implements teamsStore for unit tests. type mockTeamsStore struct { - team *model.Team - token *model.APIToken - role string - roleErr error - delErr error - delRows int64 + team *model.Team + teamWithRole *store.TeamWithRole + teamsList []store.TeamWithRole + token *model.APIToken + tokensList []model.APIToken + role string + roleErr error + delErr error + delRows int64 } +func (m *mockTeamsStore) ListTeams(_ context.Context, _ string) ([]store.TeamWithRole, error) { + return m.teamsList, nil +} +func (m *mockTeamsStore) GetTeam(_ context.Context, _, _ string) (*store.TeamWithRole, error) { + if m.teamWithRole != nil { + return m.teamWithRole, nil + } + return nil, pgx.ErrNoRows +} func (m *mockTeamsStore) CreateTeam(_ context.Context, _, _ string) (*model.Team, error) { return m.team, nil } @@ -261,6 +275,9 @@ func (m *mockTeamsStore) DeleteTeam(_ context.Context, _ string) error { return m.delErr } +func (m *mockTeamsStore) ListTokens(_ context.Context, _ string) ([]model.APIToken, error) { + return m.tokensList, nil +} func (m *mockTeamsStore) CreateToken(_ context.Context, _, _, _, _, _ string) (*model.APIToken, error) { return m.token, nil } @@ -271,6 +288,142 @@ func (m *mockTeamsStore) DeleteToken(_ context.Context, _, _ string) (int64, err // --- Audit logging tests --- +func TestTeamsHandler_List_WithStore(t *testing.T) { + ms := &mockTeamsStore{ + teamsList: []store.TeamWithRole{ + {Team: model.Team{ID: "team-1", Name: "Alpha"}, Role: "owner"}, + {Team: model.Team{ID: "team-2", Name: "Beta"}, Role: "maintainer"}, + }, + } + h := &TeamsHandler{Store: ms, AuditStore: &capAuditLogger{}} + + req := httptest.NewRequest("GET", "/api/v1/teams", nil) + req = testWithClaims(req, testClaims) + w := httptest.NewRecorder() + + h.List(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("List with store: got %d, want %d (body: %s)", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + teams, ok := resp["teams"].([]interface{}) + if !ok || len(teams) != 2 { + t.Errorf("expected 2 teams, got %v", resp["teams"]) + } +} + +func TestTeamsHandler_List_WithStore_Empty(t *testing.T) { + ms := &mockTeamsStore{} + h := &TeamsHandler{Store: ms, AuditStore: &capAuditLogger{}} + + req := httptest.NewRequest("GET", "/api/v1/teams", nil) + req = testWithClaims(req, testClaims) + w := httptest.NewRecorder() + + h.List(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("List empty: got %d, want %d", w.Code, http.StatusOK) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + teams, ok := resp["teams"].([]interface{}) + if !ok || len(teams) != 0 { + t.Errorf("expected empty teams array, got %v", resp["teams"]) + } +} + +func TestTeamsHandler_Get_WithStore_Found(t *testing.T) { + twr := &store.TeamWithRole{Team: model.Team{ID: "team-1", Name: "Alpha"}, Role: "owner"} + ms := &mockTeamsStore{teamWithRole: twr} + h := &TeamsHandler{Store: ms, AuditStore: &capAuditLogger{}} + + req := httptest.NewRequest("GET", "/api/v1/teams/team-1", nil) + req = testWithClaimsAndParam(req, testClaims, "teamID", "team-1") + w := httptest.NewRecorder() + + h.Get(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Get with store: got %d, want %d (body: %s)", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp["role"] != "owner" { + t.Errorf("role = %v, want owner", resp["role"]) + } +} + +func TestTeamsHandler_Get_WithStore_NotFound(t *testing.T) { + ms := &mockTeamsStore{} + h := &TeamsHandler{Store: ms, AuditStore: &capAuditLogger{}} + + req := httptest.NewRequest("GET", "/api/v1/teams/team-1", nil) + req = testWithClaimsAndParam(req, testClaims, "teamID", "team-1") + w := httptest.NewRecorder() + + h.Get(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Get not found: got %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestTeamsHandler_ListTokens_WithStore(t *testing.T) { + ms := &mockTeamsStore{ + role: "owner", + tokensList: []model.APIToken{ + {ID: "tok-1", TeamID: "team-1", Name: "ci", Prefix: "sct_ci"}, + }, + } + h := &TeamsHandler{Store: ms, AuditStore: &capAuditLogger{}} + + req := httptest.NewRequest("GET", "/api/v1/teams/team-1/tokens", nil) + req = testWithClaimsAndParam(req, testClaims, "teamID", "team-1") + w := httptest.NewRecorder() + + h.ListTokens(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("ListTokens with store: got %d, want %d (body: %s)", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + tokens, ok := resp["tokens"].([]interface{}) + if !ok || len(tokens) != 1 { + t.Errorf("expected 1 token, got %v", resp["tokens"]) + } +} + +func TestTeamsHandler_ListTokens_WithStore_NotMember(t *testing.T) { + ms := &mockTeamsStore{roleErr: pgx.ErrNoRows} + h := &TeamsHandler{Store: ms, AuditStore: &capAuditLogger{}} + + req := httptest.NewRequest("GET", "/api/v1/teams/team-1/tokens", nil) + req = testWithClaimsAndParam(req, testClaims, "teamID", "team-1") + w := httptest.NewRecorder() + + h.ListTokens(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("ListTokens not member: got %d, want %d", w.Code, http.StatusNotFound) + } +} + func TestTeamsCreate_LogsAuditEvent(t *testing.T) { team := &model.Team{ID: "team-1", Name: "My Team", CreatedAt: time.Now()} ms := &mockTeamsStore{team: team} diff --git a/internal/server/routes.go b/internal/server/routes.go index 060d8aeb..5df650cc 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -162,13 +162,15 @@ func NewRouter(cfg *config.Config, pool ...*db.Pool) http.Handler { } // Handlers - oauthH := &handler.OAuthHandler{JWT: jwtMgr, DB: dbPool, OAuth: oauthCfgs, Secure: isSecure} + oauthH := &handler.OAuthHandler{JWT: jwtMgr, OAuth: oauthCfgs, Secure: isSecure} + if dbPool != nil { + oauthH.OAuthStore = store.NewOAuthStore(dbPool) + } authH := &handler.AuthHandler{JWT: jwtMgr} if dbPool != nil { - authH.DB = dbPool + authH.AuthStore = store.NewAuthStore(dbPool) } reportsH := &handler.ReportsHandler{ - DB: dbPool, AuditStore: auditStore, QualityGateStore: qgStore, Webhooks: whNotifier, @@ -177,11 +179,13 @@ func NewRouter(cfg *config.Config, pool ...*db.Pool) http.Handler { TriageEnqueuer: triageEnqueuer, AllowBackdate: cfg.DisableRateLimit, } + if dbPool != nil { + reportsH.ReportStore = store.NewReportsStore(dbPool) + } if triageStore != nil { reportsH.TriageStore = triageStore } execH := &handler.ExecutionsHandler{ - DB: dbPool, Hub: wsHub, AuditStore: auditStore, K8s: k8sClient, @@ -190,17 +194,29 @@ func NewRouter(cfg *config.Config, pool ...*db.Pool) http.Handler { APIBaseURL: cfg.BaseURL, Webhooks: whNotifier, } - analyticsH := &handler.AnalyticsHandler{DB: dbPool} - qgH := &handler.QualityGatesHandler{DB: dbPool, AuditStore: auditStore} + if dbPool != nil { + execH.ExecStore = store.NewExecutionsStore(dbPool) + } + analyticsH := &handler.AnalyticsHandler{} + if dbPool != nil { + analyticsH.AnalyticsStore = store.NewAnalyticsStore(dbPool) + } + qgH := &handler.QualityGatesHandler{AuditStore: auditStore} if qgStore != nil { qgH.Store = qgStore } - teamsH := &handler.TeamsHandler{DB: dbPool, AuditStore: auditStore} + if dbPool != nil { + qgH.ReportStore = store.NewReportsStore(dbPool) + } + teamsH := &handler.TeamsHandler{AuditStore: auditStore} if dbPool != nil { teamsH.Store = store.NewTeamsStore(dbPool) } shardH := &handler.ShardingHandler{DurationStore: durStore} - adminH := &handler.AdminHandler{AuditStore: auditStore, DB: dbPool} + adminH := &handler.AdminHandler{AuditStore: auditStore} + if dbPool != nil { + adminH.AdminStore = store.NewAdminStore(dbPool) + } whH := &handler.WebhooksHandler{Dispatcher: webhook.NewDispatcher(), AuditStore: auditStore} if whStore != nil { whH.Store = whStore @@ -210,7 +226,6 @@ func NewRouter(cfg *config.Config, pool ...*db.Pool) http.Handler { } invH := &handler.InvitationsHandler{ - DB: dbPool, BaseURL: cfg.BaseURL, Mailer: mailer.New(cfg.SMTPHost, cfg.SMTPPort, cfg.SMTPUser, cfg.SMTPPass, cfg.SMTPFrom), AuditStore: auditStore, @@ -317,7 +332,6 @@ func NewRouter(cfg *config.Config, pool ...*db.Pool) http.Handler { }) }) - r.Route("/sharding", func(r chi.Router) { r.Post("/plan", shardH.CreatePlan) r.Post("/rebalance", shardH.Rebalance) @@ -349,7 +363,6 @@ func NewRouter(cfg *config.Config, pool ...*db.Pool) http.Handler { return r } - func zerologMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Debug(). diff --git a/internal/server/routes_test.go b/internal/server/routes_test.go index b342e5bb..e4a0a485 100644 --- a/internal/server/routes_test.go +++ b/internal/server/routes_test.go @@ -151,7 +151,7 @@ func TestAuthenticatedEndpointsWithToken(t *testing.T) { path string wantStatus int }{ - {"GET", "/api/v1/reports", http.StatusServiceUnavailable}, // no DB configured + {"GET", "/api/v1/reports", http.StatusServiceUnavailable}, // no DB configured {"GET", "/api/v1/executions", http.StatusServiceUnavailable}, // no DB configured {"GET", "/api/v1/analytics/trends", http.StatusServiceUnavailable}, // no DB configured {"GET", "/api/v1/analytics/flaky-tests", http.StatusServiceUnavailable}, // no DB configured @@ -205,8 +205,8 @@ func TestCTRFReportIngestion(t *testing.T) { w := httptest.NewRecorder() router.ServeHTTP(w, req) - if w.Code != http.StatusCreated { - t.Errorf("POST /api/v1/reports status = %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("POST /api/v1/reports status = %d, want %d (no DB configured)", w.Code, http.StatusServiceUnavailable) } } @@ -433,9 +433,9 @@ func TestMaintainerCanCreateReport(t *testing.T) { w := httptest.NewRecorder() router.ServeHTTP(w, req) - // Without DB, should get 201 (no-DB fallback), not 403 - if w.Code != http.StatusCreated { - t.Errorf("maintainer POST /api/v1/reports: status = %d, want %d (body: %s)", w.Code, http.StatusCreated, w.Body.String()) + // Without DB, should get 503 since the no-DB fallback path was removed + if w.Code != http.StatusServiceUnavailable { + t.Errorf("maintainer POST /api/v1/reports: status = %d, want %d (no DB configured)", w.Code, http.StatusServiceUnavailable) } } diff --git a/internal/server/team_access_test.go b/internal/server/team_access_test.go index 55457ddf..f24e4f09 100644 --- a/internal/server/team_access_test.go +++ b/internal/server/team_access_test.go @@ -366,9 +366,9 @@ func TestTeamIsolation_CTRFIngestionAcceptsBothTeams(t *testing.T) { w := httptest.NewRecorder() router.ServeHTTP(w, req) - if w.Code != http.StatusCreated { - t.Errorf("POST /api/v1/reports for %s: status = %d, want 201 (body: %s)", - tc.name, w.Code, w.Body.String()) + if w.Code != http.StatusServiceUnavailable { + t.Errorf("POST /api/v1/reports for %s: status = %d, want 503 (no DB configured)", + tc.name, w.Code) } }) } diff --git a/internal/store/admin.go b/internal/store/admin.go new file mode 100644 index 00000000..9eacd39e --- /dev/null +++ b/internal/store/admin.go @@ -0,0 +1,50 @@ +package store + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/scaledtest/scaledtest/internal/model" +) + +type AdminStore struct { + pool *pgxpool.Pool +} + +func NewAdminStore(pool *pgxpool.Pool) *AdminStore { + return &AdminStore{pool: pool} +} + +func (s *AdminStore) ListUsers(ctx context.Context, limit, offset int) ([]model.User, int, error) { + rows, err := s.pool.Query(ctx, + `SELECT id, email, display_name, role, created_at, updated_at + FROM users + ORDER BY created_at DESC + LIMIT $1 OFFSET $2`, + limit, offset) + if err != nil { + return nil, 0, err + } + defer rows.Close() + + var users []model.User + for rows.Next() { + var u model.User + if err := rows.Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role, &u.CreatedAt, &u.UpdatedAt); err != nil { + return nil, 0, err + } + users = append(users, u) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + + var total int + err = s.pool.QueryRow(ctx, `SELECT COUNT(*) FROM users`).Scan(&total) + if err != nil { + return nil, 0, err + } + + return users, total, nil +} diff --git a/internal/store/analytics.go b/internal/store/analytics.go new file mode 100644 index 00000000..0f254470 --- /dev/null +++ b/internal/store/analytics.go @@ -0,0 +1,169 @@ +package store + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5/pgxpool" +) + +type AnalyticsStore struct { + pool *pgxpool.Pool +} + +func NewAnalyticsStore(pool *pgxpool.Pool) *AnalyticsStore { + return &AnalyticsStore{pool: pool} +} + +type TrendRow struct { + Date time.Time + Total int + Passed int + Failed int + Skipped int + PassRate float64 +} + +func (s *AnalyticsStore) QueryTrends(ctx context.Context, groupBy, teamID string, start, end time.Time) ([]TrendRow, error) { + query := ` + SELECT + time_bucket($1::interval, created_at) AS bucket, + count(*) AS total, + count(*) FILTER (WHERE status = 'passed') AS passed, + count(*) FILTER (WHERE status = 'failed') AS failed, + count(*) FILTER (WHERE status = 'skipped') AS skipped, + CASE WHEN count(*) > 0 + THEN round(count(*) FILTER (WHERE status = 'passed')::numeric / count(*)::numeric * 100, 2) + ELSE 0 + END AS pass_rate + FROM test_results + WHERE team_id = $2 + AND created_at >= $3 + AND created_at <= $4 + GROUP BY bucket + ORDER BY bucket + ` + rows, err := s.pool.Query(ctx, query, groupBy, teamID, start, end) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []TrendRow + for rows.Next() { + var tr TrendRow + if err := rows.Scan(&tr.Date, &tr.Total, &tr.Passed, &tr.Failed, &tr.Skipped, &tr.PassRate); err != nil { + return nil, err + } + results = append(results, tr) + } + return results, rows.Err() +} + +func (s *AnalyticsStore) QueryDurationBuckets(ctx context.Context, teamID string, start, end time.Time) ([]int64, error) { + rows, err := s.pool.Query(ctx, + `SELECT duration_ms + FROM test_results + WHERE team_id = $1 AND created_at >= $2 AND created_at <= $3`, + teamID, start, end) + if err != nil { + return nil, err + } + defer rows.Close() + + var durations []int64 + for rows.Next() { + var ms int64 + if err := rows.Scan(&ms); err != nil { + return nil, err + } + durations = append(durations, ms) + } + return durations, rows.Err() +} + +type ErrorClusterRow struct { + Message string + Count int + TestNames []string + FirstSeen time.Time + LastSeen time.Time +} + +func (s *AnalyticsStore) QueryErrorClusters(ctx context.Context, teamID string, start, end time.Time, limit int) ([]ErrorClusterRow, error) { + query := ` + SELECT + message, + count(*) AS count, + array_agg(DISTINCT name) AS test_names, + min(created_at) AS first_seen, + max(created_at) AS last_seen + FROM test_results + WHERE team_id = $1 + AND status = 'failed' + AND message IS NOT NULL + AND message != '' + AND created_at >= $2 + AND created_at <= $3 + GROUP BY message + ORDER BY count DESC + LIMIT $4 + ` + rows, err := s.pool.Query(ctx, query, teamID, start, end, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var clusters []ErrorClusterRow + for rows.Next() { + var ec ErrorClusterRow + if err := rows.Scan(&ec.Message, &ec.Count, &ec.TestNames, &ec.FirstSeen, &ec.LastSeen); err != nil { + return nil, err + } + clusters = append(clusters, ec) + } + return clusters, rows.Err() +} + +type FlakyRow struct { + Name string + Suite string + FilePath string + Statuses []string + LastStatus string + TotalRuns int +} + +func (s *AnalyticsStore) QueryFlakyTests(ctx context.Context, teamID string, cutoff time.Time, minRuns int) ([]FlakyRow, error) { + query := ` + SELECT + name, + COALESCE(suite, '') AS suite, + COALESCE(file_path, '') AS file_path, + array_agg(status ORDER BY created_at) AS statuses, + (array_agg(status ORDER BY created_at DESC))[1] AS last_status, + count(*) AS total_runs + FROM test_results + WHERE team_id = $1 + AND created_at >= $2 + GROUP BY name, suite, file_path + HAVING count(*) >= $3 + ORDER BY name + ` + rows, err := s.pool.Query(ctx, query, teamID, cutoff, minRuns) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []FlakyRow + for rows.Next() { + var fr FlakyRow + if err := rows.Scan(&fr.Name, &fr.Suite, &fr.FilePath, &fr.Statuses, &fr.LastStatus, &fr.TotalRuns); err != nil { + return nil, err + } + results = append(results, fr) + } + return results, rows.Err() +} diff --git a/internal/store/auth.go b/internal/store/auth.go new file mode 100644 index 00000000..f4303564 --- /dev/null +++ b/internal/store/auth.go @@ -0,0 +1,144 @@ +package store + +import ( + "context" + "net" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/scaledtest/scaledtest/internal/model" +) + +type AuthStore struct { + pool *pgxpool.Pool +} + +func NewAuthStore(pool *pgxpool.Pool) *AuthStore { + return &AuthStore{pool: pool} +} + +func (s *AuthStore) GetUserByEmail(ctx context.Context, email string) (*model.User, error) { + var u model.User + err := s.pool.QueryRow(ctx, + `SELECT id, email, password_hash, display_name, role FROM users WHERE email = $1`, + email, + ).Scan(&u.ID, &u.Email, &u.PasswordHash, &u.DisplayName, &u.Role) + if err != nil { + return nil, err + } + return &u, nil +} + +func (s *AuthStore) GetUserByID(ctx context.Context, id string) (*model.User, error) { + var u model.User + err := s.pool.QueryRow(ctx, + `SELECT id, email, password_hash, display_name, role FROM users WHERE id = $1`, + id, + ).Scan(&u.ID, &u.Email, &u.PasswordHash, &u.DisplayName, &u.Role) + if err != nil { + return nil, err + } + return &u, nil +} + +func (s *AuthStore) EmailExists(ctx context.Context, email string) (bool, error) { + var exists bool + err := s.pool.QueryRow(ctx, + `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, email).Scan(&exists) + return exists, err +} + +func (s *AuthStore) CreateUser(ctx context.Context, email, passwordHash, displayName, role string) (userID, returnedRole string, err error) { + err = s.pool.QueryRow(ctx, + `INSERT INTO users (email, password_hash, display_name, role) + SELECT $1, $2, $3, + CASE WHEN NOT EXISTS (SELECT 1 FROM users) THEN 'owner'::text ELSE 'maintainer'::text END + RETURNING id, role`, + email, passwordHash, displayName, + ).Scan(&userID, &returnedRole) + return +} + +func (s *AuthStore) CreateUserWithRole(ctx context.Context, email, passwordHash, displayName, role string) (userID string, err error) { + err = s.pool.QueryRow(ctx, + `INSERT INTO users (email, password_hash, display_name, role) + VALUES ($1, $2, $3, $4) + RETURNING id`, + email, passwordHash, displayName, role, + ).Scan(&userID) + return +} + +func (s *AuthStore) UpdatePassword(ctx context.Context, userID, passwordHash string) (int64, error) { + tag, err := s.pool.Exec(ctx, + `UPDATE users SET password_hash = $1 WHERE id = $2`, + passwordHash, userID, + ) + return tag.RowsAffected(), err +} + +func (s *AuthStore) UpdateProfile(ctx context.Context, userID, displayName string) (*model.User, error) { + var u model.User + err := s.pool.QueryRow(ctx, + `UPDATE users SET display_name = $1, updated_at = now() + WHERE id = $2 + RETURNING id, email, display_name, role`, + displayName, userID, + ).Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role) + if err != nil { + return nil, err + } + return &u, nil +} + +func (s *AuthStore) GetPrimaryTeamID(ctx context.Context, userID string) (string, error) { + var teamID string + err := s.pool.QueryRow(ctx, + `SELECT team_id FROM user_teams WHERE user_id = $1 ORDER BY joined_at ASC LIMIT 1`, + userID, + ).Scan(&teamID) + if err != nil { + return "", err + } + return teamID, nil +} + +func (s *AuthStore) CreateSession(ctx context.Context, userID, refreshToken string, userAgent string, ipAddr net.IP, expiresAt time.Time) error { + _, err := s.pool.Exec(ctx, + `INSERT INTO sessions (user_id, refresh_token, user_agent, ip_address, expires_at) + VALUES ($1, $2, $3, $4, $5)`, + userID, refreshToken, userAgent, ipAddr, expiresAt, + ) + return err +} + +type SessionInfo struct { + ID string + UserID string + ExpiresAt time.Time +} + +func (s *AuthStore) GetSessionByRefreshToken(ctx context.Context, refreshToken string) (*SessionInfo, error) { + var si SessionInfo + err := s.pool.QueryRow(ctx, + `SELECT s.id, s.user_id, s.expires_at + FROM sessions s + WHERE s.refresh_token = $1`, + refreshToken, + ).Scan(&si.ID, &si.UserID, &si.ExpiresAt) + if err != nil { + return nil, err + } + return &si, nil +} + +func (s *AuthStore) DeleteSession(ctx context.Context, sessionID string) error { + _, err := s.pool.Exec(ctx, `DELETE FROM sessions WHERE id = $1`, sessionID) + return err +} + +func (s *AuthStore) DeleteSessionByRefreshToken(ctx context.Context, refreshToken string) error { + _, err := s.pool.Exec(ctx, `DELETE FROM sessions WHERE refresh_token = $1`, refreshToken) + return err +} diff --git a/internal/store/executions.go b/internal/store/executions.go new file mode 100644 index 00000000..98957edf --- /dev/null +++ b/internal/store/executions.go @@ -0,0 +1,159 @@ +package store + +import ( + "context" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/scaledtest/scaledtest/internal/model" +) + +type ExecutionsStore struct { + pool *pgxpool.Pool +} + +func NewExecutionsStore(pool *pgxpool.Pool) *ExecutionsStore { + return &ExecutionsStore{pool: pool} +} + +func (s *ExecutionsStore) List(ctx context.Context, teamID string, limit, offset int) ([]model.TestExecution, int, error) { + rows, err := s.pool.Query(ctx, + `SELECT id, team_id, status, command, config, report_id, k8s_job_name, k8s_pod_name, + error_msg, started_at, finished_at, created_at, updated_at + FROM test_executions + WHERE team_id = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3`, + teamID, limit, offset) + if err != nil { + return nil, 0, err + } + defer rows.Close() + + var executions []model.TestExecution + for rows.Next() { + var e model.TestExecution + if err := rows.Scan( + &e.ID, &e.TeamID, &e.Status, &e.Command, &e.Config, &e.ReportID, + &e.K8sJobName, &e.K8sPodName, &e.ErrorMsg, &e.StartedAt, + &e.FinishedAt, &e.CreatedAt, &e.UpdatedAt, + ); err != nil { + return nil, 0, err + } + executions = append(executions, e) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + + var total int + err = s.pool.QueryRow(ctx, + `SELECT COUNT(*) FROM test_executions WHERE team_id = $1`, + teamID).Scan(&total) + if err != nil { + return nil, 0, err + } + + return executions, total, nil +} + +func (s *ExecutionsStore) Create(ctx context.Context, teamID, command string, configJSON []byte) (string, error) { + id := uuid.New().String() + now := time.Now() + _, err := s.pool.Exec(ctx, + `INSERT INTO test_executions (id, team_id, status, command, config, created_at, updated_at) + VALUES ($1, $2, 'pending', $3, $4, $5, $5)`, + id, teamID, command, configJSON, now) + return id, err +} + +func (s *ExecutionsStore) Get(ctx context.Context, id, teamID string) (*model.TestExecution, error) { + var e model.TestExecution + err := s.pool.QueryRow(ctx, + `SELECT id, team_id, status, command, config, report_id, k8s_job_name, k8s_pod_name, + error_msg, started_at, finished_at, created_at, updated_at + FROM test_executions + WHERE id = $1 AND team_id = $2`, + id, teamID).Scan( + &e.ID, &e.TeamID, &e.Status, &e.Command, &e.Config, &e.ReportID, + &e.K8sJobName, &e.K8sPodName, &e.ErrorMsg, &e.StartedAt, + &e.FinishedAt, &e.CreatedAt, &e.UpdatedAt) + if err != nil { + return nil, err + } + return &e, nil +} + +func (s *ExecutionsStore) Cancel(ctx context.Context, id, teamID string, now time.Time) (int64, error) { + tag, err := s.pool.Exec(ctx, + `UPDATE test_executions + SET status = 'cancelled', finished_at = $1, updated_at = $1 + WHERE id = $2 AND team_id = $3 AND status IN ('pending', 'running')`, + now, id, teamID) + return tag.RowsAffected(), err +} + +func (s *ExecutionsStore) UpdateStatus(ctx context.Context, id, teamID, status string, now time.Time, errorMsg *string) (int64, error) { + query := `UPDATE test_executions SET status = $1, updated_at = $2` + args := []interface{}{status, now} + argIdx := 3 + + if status == "running" { + query += `, started_at = COALESCE(started_at, $3)` + args = append(args, now) + argIdx++ + } + + if status == "completed" || status == "failed" || status == "cancelled" { + query += `, finished_at = $` + strconv.Itoa(argIdx) + args = append(args, now) + argIdx++ + } + + if errorMsg != nil && *errorMsg != "" { + query += `, error_msg = $` + strconv.Itoa(argIdx) + args = append(args, *errorMsg) + argIdx++ + } + + query += ` WHERE id = $` + strconv.Itoa(argIdx) + ` AND team_id = $` + strconv.Itoa(argIdx+1) + args = append(args, id, teamID) + + tag, err := s.pool.Exec(ctx, query, args...) + return tag.RowsAffected(), err +} + +func (s *ExecutionsStore) Exists(ctx context.Context, id, teamID string) (bool, error) { + var exists bool + err := s.pool.QueryRow(ctx, + `SELECT EXISTS(SELECT 1 FROM test_executions WHERE id = $1 AND team_id = $2)`, + id, teamID).Scan(&exists) + return exists, err +} + +func (s *ExecutionsStore) GetK8sJobName(ctx context.Context, id string) (*string, error) { + var jobName *string + err := s.pool.QueryRow(ctx, + `SELECT k8s_job_name FROM test_executions WHERE id = $1`, id).Scan(&jobName) + if err != nil { + return nil, err + } + return jobName, nil +} + +func (s *ExecutionsStore) SetK8sJobName(ctx context.Context, id, jobName string, now time.Time) error { + _, err := s.pool.Exec(ctx, + `UPDATE test_executions SET k8s_job_name = $1, updated_at = $2 WHERE id = $3`, + jobName, now, id) + return err +} + +func (s *ExecutionsStore) MarkFailed(ctx context.Context, id, errorMsg string, now time.Time) error { + _, err := s.pool.Exec(ctx, + `UPDATE test_executions SET status = 'failed', error_msg = $1, updated_at = $2 WHERE id = $3`, + errorMsg, now, id) + return err +} diff --git a/internal/store/helpers.go b/internal/store/helpers.go new file mode 100644 index 00000000..3b93429d --- /dev/null +++ b/internal/store/helpers.go @@ -0,0 +1,10 @@ +package store + +// NullString returns a *string that is nil for empty strings. +// This is the shared helper for converting empty strings to NULL for database inserts. +func NullString(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/internal/store/invitations.go b/internal/store/invitations.go index 6270849c..af43d04c 100644 --- a/internal/store/invitations.go +++ b/internal/store/invitations.go @@ -147,3 +147,13 @@ func (s *InvitationStore) Delete(ctx context.Context, teamID, id string) error { } return nil } + +// GetTeamName returns the name of the team with the given ID. +func (s *InvitationStore) GetTeamName(ctx context.Context, teamID string) (string, error) { + var name string + err := s.pool.QueryRow(ctx, `SELECT name FROM teams WHERE id = $1`, teamID).Scan(&name) + if err != nil { + return "", fmt.Errorf("get team name: %w", err) + } + return name, nil +} diff --git a/internal/store/oauth.go b/internal/store/oauth.go new file mode 100644 index 00000000..25736473 --- /dev/null +++ b/internal/store/oauth.go @@ -0,0 +1,91 @@ +package store + +import ( + "context" + "net" + "time" + + "github.com/jackc/pgx/v5/pgxpool" +) + +type OAuthStore struct { + pool *pgxpool.Pool +} + +func NewOAuthStore(pool *pgxpool.Pool) *OAuthStore { + return &OAuthStore{pool: pool} +} + +type OAuthLinkedUser struct { + ID string + Email string + DisplayName string + Role string +} + +func (s *OAuthStore) FindLinkedUser(ctx context.Context, provider, providerID string) (*OAuthLinkedUser, error) { + var u OAuthLinkedUser + err := s.pool.QueryRow(ctx, + `SELECT u.id, u.email, u.display_name, u.role + FROM oauth_accounts oa + JOIN users u ON u.id = oa.user_id + WHERE oa.provider = $1 AND oa.provider_id = $2`, + provider, providerID, + ).Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role) + if err != nil { + return nil, err + } + return &u, nil +} + +func (s *OAuthStore) FindUserByEmail(ctx context.Context, email string) (*OAuthLinkedUser, error) { + var u OAuthLinkedUser + err := s.pool.QueryRow(ctx, + `SELECT id, email, display_name, role FROM users WHERE email = $1`, email, + ).Scan(&u.ID, &u.Email, &u.DisplayName, &u.Role) + if err != nil { + return nil, err + } + return &u, nil +} + +func (s *OAuthStore) CreateUser(ctx context.Context, email, displayName string) (string, string, error) { + var userID, role string + err := s.pool.QueryRow(ctx, + `INSERT INTO users (email, password_hash, display_name) + VALUES ($1, '', $2) + RETURNING id, role`, + email, displayName, + ).Scan(&userID, &role) + if err != nil { + return "", "", err + } + return userID, role, nil +} + +func (s *OAuthStore) LinkAccount(ctx context.Context, userID, provider, providerID, accessToken, refreshToken string) error { + _, err := s.pool.Exec(ctx, + `INSERT INTO oauth_accounts (user_id, provider, provider_id, access_token, refresh_token) + VALUES ($1, $2, $3, $4, $5)`, + userID, provider, providerID, accessToken, refreshToken, + ) + return err +} + +func (s *OAuthStore) UpdateTokens(ctx context.Context, accessToken, refreshToken, provider, providerID string) error { + _, err := s.pool.Exec(ctx, + `UPDATE oauth_accounts SET access_token = $1, refresh_token = $2 + WHERE provider = $3 AND provider_id = $4`, + accessToken, refreshToken, provider, providerID, + ) + return err +} + +func (s *OAuthStore) CreateSession(ctx context.Context, userID, refreshToken, userAgent string, ipAddr net.IP, expiresAt time.Time) error { + _, err := s.pool.Exec(ctx, + `INSERT INTO sessions (user_id, refresh_token, user_agent, ip_address, expires_at) + VALUES ($1, $2, $3, $4, $5)`, + userID, refreshToken, userAgent, ipAddr, expiresAt, + ) + return err +} diff --git a/internal/store/reports.go b/internal/store/reports.go new file mode 100644 index 00000000..5e3288c3 --- /dev/null +++ b/internal/store/reports.go @@ -0,0 +1,290 @@ +package store + +import ( + "context" + "encoding/json" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/scaledtest/scaledtest/internal/model" +) + +type ReportsStore struct { + pool *pgxpool.Pool +} + +func NewReportsStore(pool *pgxpool.Pool) *ReportsStore { + return &ReportsStore{pool: pool} +} + +type ReportListFilter struct { + TeamID string + Since *time.Time + Until *time.Time + Limit int + Offset int +} + +func (s *ReportsStore) List(ctx context.Context, f ReportListFilter) ([]map[string]interface{}, int, error) { + whereClause := ` WHERE team_id = $1` + args := []interface{}{f.TeamID} + argIdx := 2 + + if f.Since != nil { + whereClause += ` AND created_at >= $` + strconv.Itoa(argIdx) + args = append(args, *f.Since) + argIdx++ + } + if f.Until != nil { + whereClause += ` AND created_at <= $` + strconv.Itoa(argIdx) + args = append(args, *f.Until) + argIdx++ + } + + countQuery := `SELECT COUNT(*) FROM test_reports` + whereClause + var total int + if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, 0, err + } + + query := `SELECT id, team_id, execution_id, tool_name, tool_version, environment, summary, created_at + FROM test_reports` + whereClause + + ` ORDER BY created_at DESC LIMIT $` + strconv.Itoa(argIdx) + ` OFFSET $` + strconv.Itoa(argIdx+1) + dataArgs := append(args, f.Limit, f.Offset) + + rows, err := s.pool.Query(ctx, query, dataArgs...) + if err != nil { + return nil, 0, err + } + defer rows.Close() + + var reports []map[string]interface{} + for rows.Next() { + var rpt model.TestReport + if err := rows.Scan( + &rpt.ID, &rpt.TeamID, &rpt.ExecutionID, &rpt.ToolName, + &rpt.ToolVersion, &rpt.Environment, &rpt.Summary, &rpt.CreatedAt, + ); err != nil { + return nil, 0, err + } + reports = append(reports, flattenReport(rpt)) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + + return reports, total, nil +} + +func flattenReport(rpt model.TestReport) map[string]interface{} { + out := map[string]interface{}{ + "id": rpt.ID, + "team_id": rpt.TeamID, + "tool_name": rpt.ToolName, + "summary": rpt.Summary, + "created_at": rpt.CreatedAt, + } + if rpt.ToolVersion != "" { + out["tool_version"] = rpt.ToolVersion + } + if rpt.ExecutionID != nil { + out["execution_id"] = *rpt.ExecutionID + } + if len(rpt.Environment) > 0 { + out["environment"] = rpt.Environment + } + + var s model.ReportSummary + if err := json.Unmarshal(rpt.Summary, &s); err == nil { + out["test_count"] = s.Tests + out["passed"] = s.Passed + out["failed"] = s.Failed + out["skipped"] = s.Skipped + out["pending"] = s.Pending + } + return out +} + +type CreateReportParams struct { + ID string + TeamID string + ExecutionID *string + ToolName string + ToolVersion string + Environment json.RawMessage + Summary json.RawMessage + Raw json.RawMessage + CreatedAt time.Time + TriageGitHubStatus bool +} + +func (s *ReportsStore) CreateWithResults(ctx context.Context, p CreateReportParams, results []model.TestResult) error { + tx, err := s.pool.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, + `INSERT INTO test_reports (id, team_id, execution_id, tool_name, tool_version, environment, summary, raw, created_at, triage_github_status) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`, + p.ID, p.TeamID, p.ExecutionID, + p.ToolName, p.ToolVersion, + p.Environment, p.Summary, p.Raw, p.CreatedAt, + p.TriageGitHubStatus) + if err != nil { + return err + } + + batch := &pgx.Batch{} + for _, res := range results { + resID := uuid.New().String() + batch.Queue( + `INSERT INTO test_results (id, report_id, team_id, name, status, duration_ms, message, trace, file_path, suite, tags, retry, flaky, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)`, + resID, res.ReportID, res.TeamID, res.Name, res.Status, + res.DurationMs, NullString(res.Message), NullString(res.Trace), + NullString(res.FilePath), NullString(res.Suite), + res.Tags, res.Retry, res.Flaky, p.CreatedAt, + ) + } + br := tx.SendBatch(ctx, batch) + for range results { + if _, err := br.Exec(); err != nil { + br.Close() + return err + } + } + br.Close() + + if p.ExecutionID != nil { + tag, err := tx.Exec(ctx, + `UPDATE test_executions SET report_id = $1, updated_at = $2 + WHERE id = $3 AND team_id = $4`, + p.ID, p.CreatedAt, *p.ExecutionID, p.TeamID) + if err != nil { + return err + } + if tag.RowsAffected() == 0 { + return pgx.ErrNoRows + } + } + + return tx.Commit(ctx) +} + +func (s *ReportsStore) Get(ctx context.Context, id, teamID string) (*model.TestReport, error) { + var rpt model.TestReport + err := s.pool.QueryRow(ctx, + `SELECT id, team_id, execution_id, tool_name, tool_version, environment, summary, created_at + FROM test_reports + WHERE id = $1 AND team_id = $2`, + id, teamID).Scan( + &rpt.ID, &rpt.TeamID, &rpt.ExecutionID, &rpt.ToolName, + &rpt.ToolVersion, &rpt.Environment, &rpt.Summary, &rpt.CreatedAt, + ) + if err != nil { + return nil, err + } + return &rpt, nil +} + +func (s *ReportsStore) Delete(ctx context.Context, id, teamID string) (int64, error) { + tag, err := s.pool.Exec(ctx, + `DELETE FROM test_reports WHERE id = $1 AND team_id = $2`, + id, teamID) + return tag.RowsAffected(), err +} + +func (s *ReportsStore) ExecutionExists(ctx context.Context, executionID, teamID string) (bool, error) { + var exists bool + err := s.pool.QueryRow(ctx, + `SELECT EXISTS(SELECT 1 FROM test_executions WHERE id = $1 AND team_id = $2)`, + executionID, teamID).Scan(&exists) + return exists, err +} + +func (s *ReportsStore) GetReportAndResults(ctx context.Context, id, teamID string) (*model.TestReport, map[string]*model.TestResult, error) { + var rpt model.TestReport + err := s.pool.QueryRow(ctx, + `SELECT id, team_id, execution_id, tool_name, tool_version, summary, created_at + FROM test_reports WHERE id = $1 AND team_id = $2`, + id, teamID).Scan( + &rpt.ID, &rpt.TeamID, &rpt.ExecutionID, &rpt.ToolName, + &rpt.ToolVersion, &rpt.Summary, &rpt.CreatedAt, + ) + if err != nil { + return nil, nil, err + } + + rows, err := s.pool.Query(ctx, + `SELECT id, report_id, team_id, name, status, duration_ms, + COALESCE(message, ''), COALESCE(trace, ''), COALESCE(file_path, ''), COALESCE(suite, ''), + tags, retry, flaky, created_at + FROM test_results WHERE report_id = $1 AND team_id = $2`, + id, teamID) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + results := make(map[string]*model.TestResult) + for rows.Next() { + var res model.TestResult + if err := rows.Scan( + &res.ID, &res.ReportID, &res.TeamID, &res.Name, &res.Status, + &res.DurationMs, &res.Message, &res.Trace, &res.FilePath, + &res.Suite, &res.Tags, &res.Retry, &res.Flaky, &res.CreatedAt, + ); err != nil { + return nil, nil, err + } + results[res.Name] = &res + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + + return &rpt, results, nil +} + +func (s *ReportsStore) GetPreviousFailedTests(ctx context.Context, teamID, currentReportID string) (map[string]bool, error) { + var prevReportID string + err := s.pool.QueryRow(ctx, + `SELECT id FROM test_reports WHERE team_id = $1 AND id != $2 ORDER BY created_at DESC LIMIT 1`, + teamID, currentReportID, + ).Scan(&prevReportID) + if err != nil { + if err == pgx.ErrNoRows { + return nil, nil + } + return nil, err + } + + rows, err := s.pool.Query(ctx, + `SELECT name FROM test_results WHERE report_id = $1 AND status = 'failed'`, + prevReportID) + if err != nil { + return nil, err + } + defer rows.Close() + + failed := make(map[string]bool) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + failed[name] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + if len(failed) == 0 { + return nil, nil + } + return failed, nil +} diff --git a/internal/store/reports_store_integration_test.go b/internal/store/reports_store_integration_test.go new file mode 100644 index 00000000..f17d1ab0 --- /dev/null +++ b/internal/store/reports_store_integration_test.go @@ -0,0 +1,122 @@ +//go:build integration + +package store_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/scaledtest/scaledtest/internal/integration" + "github.com/scaledtest/scaledtest/internal/model" + "github.com/scaledtest/scaledtest/internal/store" +) + +func TestReportsStore_CreateWithResults_BulkInsert(t *testing.T) { + tdb := integration.Setup(t) + ctx := context.Background() + teamID := tdb.CreateTeam(t, "reports-bulk-test-team") + s := store.NewReportsStore(tdb.Pool) + + summary, _ := json.Marshal(map[string]int{"tests": 100, "passed": 90, "failed": 8, "skipped": 2}) + raw, _ := json.Marshal(map[string]interface{}{"results": map[string]interface{}{"tool": map[string]interface{}{"name": "jest"}}}) + + params := store.CreateReportParams{ + ID: "rpt-bulk-001", + TeamID: teamID, + ToolName: "jest", + ToolVersion: "1.0.0", + Summary: summary, + Raw: raw, + CreatedAt: time.Now(), + } + + results := make([]model.TestResult, 100) + for i := range results { + status := "passed" + if i%10 == 0 { + status = "failed" + } + if i%25 == 0 { + status = "skipped" + } + results[i] = model.TestResult{ + ReportID: "rpt-bulk-001", + TeamID: teamID, + Name: "test-bulk-" + string(rune('A'+i%26)) + string(rune('0'+i%10)), + Status: status, + DurationMs: int64(i * 10), + } + } + + err := s.CreateWithResults(ctx, params, results) + if err != nil { + t.Fatalf("CreateWithResults with 100 results: %v", err) + } + + rpt, found, err := s.GetReportAndResults(ctx, "rpt-bulk-001", teamID) + if err != nil { + t.Fatalf("GetReportAndResults: %v", err) + } + if rpt == nil { + t.Fatal("expected report, got nil") + } + if len(found) != 100 { + t.Errorf("expected 100 results, got %d", len(found)) + } +} + +func TestReportsStore_CreateWithResults_BulkInsert_1000Results(t *testing.T) { + tdb := integration.Setup(t) + ctx := context.Background() + teamID := tdb.CreateTeam(t, "reports-bulk-1k-test-team") + s := store.NewReportsStore(tdb.Pool) + + summary, _ := json.Marshal(map[string]int{"tests": 1000, "passed": 900, "failed": 50, "skipped": 50}) + raw, _ := json.Marshal(map[string]interface{}{"results": map[string]interface{}{"tool": map[string]interface{}{"name": "jest"}}}) + + params := store.CreateReportParams{ + ID: "rpt-bulk-1k", + TeamID: teamID, + ToolName: "jest", + ToolVersion: "1.0.0", + Summary: summary, + Raw: raw, + CreatedAt: time.Now(), + } + + results := make([]model.TestResult, 1000) + for i := range results { + status := "passed" + if i%20 == 0 { + status = "failed" + } + if i%50 == 0 { + status = "skipped" + } + results[i] = model.TestResult{ + ReportID: "rpt-bulk-1k", + TeamID: teamID, + Name: "test-1k-" + string(rune('A'+i%26)) + string(rune('0'+i%10)) + string(rune('0'+i/10%10)), + Status: status, + DurationMs: int64(i * 5), + } + } + + err := s.CreateWithResults(ctx, params, results) + if err != nil { + t.Fatalf("CreateWithResults with 1000 results: %v", err) + } + + rpt, found, err := s.GetReportAndResults(ctx, "rpt-bulk-1k", teamID) + if err != nil { + t.Fatalf("GetReportAndResults: %v", err) + } + if rpt == nil { + t.Fatal("expected report, got nil") + } + if len(found) != 1000 { + t.Errorf("expected 1000 results, got %d", len(found)) + } +} diff --git a/internal/store/teams.go b/internal/store/teams.go index 49bc3986..ce662b9d 100644 --- a/internal/store/teams.go +++ b/internal/store/teams.go @@ -8,6 +8,12 @@ import ( "github.com/scaledtest/scaledtest/internal/model" ) +// TeamWithRole is a team paired with the current user's role in that team. +type TeamWithRole struct { + model.Team + Role string `json:"role"` +} + // TeamsStore handles team and API token persistence. type TeamsStore struct { pool *pgxpool.Pool @@ -18,6 +24,45 @@ func NewTeamsStore(pool *pgxpool.Pool) *TeamsStore { return &TeamsStore{pool: pool} } +// ListTeams returns all teams for a user with their role in each team. +func (s *TeamsStore) ListTeams(ctx context.Context, userID string) ([]TeamWithRole, error) { + rows, err := s.pool.Query(ctx, + `SELECT t.id, t.name, t.created_at, ut.role + FROM teams t + JOIN user_teams ut ON ut.team_id = t.id + WHERE ut.user_id = $1 + ORDER BY t.name`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + + var teams []TeamWithRole + for rows.Next() { + var t TeamWithRole + if err := rows.Scan(&t.ID, &t.Name, &t.CreatedAt, &t.Role); err != nil { + return nil, err + } + teams = append(teams, t) + } + return teams, nil +} + +// GetTeam returns a team with the user's role, or pgx.ErrNoRows if not found or not a member. +func (s *TeamsStore) GetTeam(ctx context.Context, teamID, userID string) (*TeamWithRole, error) { + var t TeamWithRole + err := s.pool.QueryRow(ctx, + `SELECT t.id, t.name, t.created_at, ut.role + FROM teams t + JOIN user_teams ut ON ut.team_id = t.id + WHERE t.id = $1 AND ut.user_id = $2`, teamID, userID). + Scan(&t.ID, &t.Name, &t.CreatedAt, &t.Role) + if err != nil { + return nil, err + } + return &t, nil +} + // CreateTeam creates a team and adds the user as owner atomically. func (s *TeamsStore) CreateTeam(ctx context.Context, userID, name string) (*model.Team, error) { tx, err := s.pool.Begin(ctx) @@ -65,6 +110,29 @@ func (s *TeamsStore) DeleteTeam(ctx context.Context, teamID string) error { return err } +// ListTokens returns all API tokens for a team. +func (s *TeamsStore) ListTokens(ctx context.Context, teamID string) ([]model.APIToken, error) { + rows, err := s.pool.Query(ctx, + `SELECT id, team_id, user_id, name, prefix, last_used_at, created_at + FROM api_tokens + WHERE team_id = $1 + ORDER BY created_at DESC`, teamID) + if err != nil { + return nil, err + } + defer rows.Close() + + var tokens []model.APIToken + for rows.Next() { + var t model.APIToken + if err := rows.Scan(&t.ID, &t.TeamID, &t.UserID, &t.Name, &t.Prefix, &t.LastUsedAt, &t.CreatedAt); err != nil { + return nil, err + } + tokens = append(tokens, t) + } + return tokens, nil +} + // CreateToken inserts a new API token and returns the created token. func (s *TeamsStore) CreateToken(ctx context.Context, teamID, userID, name, tokenHash, prefix string) (*model.APIToken, error) { var token model.APIToken @@ -89,4 +157,3 @@ func (s *TeamsStore) DeleteToken(ctx context.Context, teamID, tokenID string) (i } return tag.RowsAffected(), nil } - diff --git a/internal/testutil/qg_helpers.go b/internal/testutil/qg_helpers.go index 11953ea4..3cbdd6aa 100644 --- a/internal/testutil/qg_helpers.go +++ b/internal/testutil/qg_helpers.go @@ -80,8 +80,8 @@ func InsertNoNewFailuresGate(t *testing.T, ctx context.Context, pool *db.Pool, t func PostEvaluateQG(t *testing.T, pool *db.Pool, teamID, gateID, reportID string) *httptest.ResponseRecorder { t.Helper() h := &handler.QualityGatesHandler{ - Store: store.NewQualityGateStore(pool), - DB: pool, + Store: store.NewQualityGateStore(pool), + ReportStore: store.NewReportsStore(pool), } body := fmt.Sprintf(`{"report_id":%q}`, reportID) req := httptest.NewRequest(http.MethodPost,