diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 21db8b6d2..266c0a4d9 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -1,6 +1,7 @@ package redis import ( + "context" "io" "math/rand" @@ -13,19 +14,19 @@ import ( "github.com/envoyproxy/ratelimit/src/utils" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { +func NewRateLimiterCacheImplFromSettings(ctx context.Context, s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { closer := &utils.MultiCloser{} var perSecondPool Client if s.RedisPerSecond { - perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType, + perSecondPool = NewClientImpl(ctx, srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv, s.RedisPerSecondTimeout, - s.RedisPerSecondPoolOnEmptyBehavior, s.RedisPerSecondSentinelAuth) + s.RedisPerSecondPoolOnEmptyBehavior, s.RedisPerSecondSentinelAuth, s.RedisStartupInitialInterval, s.RedisStartupMaxInterval, s.RedisStartupMaxElapsedTime) closer.Closers = append(closer.Closers, perSecondPool) } - otherPool := NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize, + otherPool := NewClientImpl(ctx, srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv, s.RedisTimeout, - s.RedisPoolOnEmptyBehavior, s.RedisSentinelAuth) + s.RedisPoolOnEmptyBehavior, s.RedisSentinelAuth, s.RedisStartupInitialInterval, s.RedisStartupMaxInterval, s.RedisStartupMaxElapsedTime) closer.Closers = append(closer.Closers, otherPool) return NewFixedRateLimitCacheImpl( diff --git a/src/redis/driver_impl.go b/src/redis/driver_impl.go index 50cf44667..f65269ebe 100644 --- a/src/redis/driver_impl.go +++ b/src/redis/driver_impl.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/jpillora/backoff" stats "github.com/lyft/gostats" "github.com/mediocregopher/radix/v4" "github.com/mediocregopher/radix/v4/trace" @@ -119,9 +120,10 @@ func createDialer(timeout time.Duration, useTls bool, tlsConfig *tls.Config, aut return dialer } -func NewClientImpl(scope stats.Scope, useTls bool, auth, redisSocketType, redisType, url string, poolSize int, +func NewClientImpl(ctx context.Context, scope stats.Scope, useTls bool, auth, redisSocketType, redisType, url string, poolSize int, pipelineWindow time.Duration, pipelineLimit int, tlsConfig *tls.Config, healthCheckActiveConnection bool, srv server.Server, timeout time.Duration, poolOnEmptyBehavior string, sentinelAuth string, + startupInitialInterval, startupMaxInterval, startupMaxElapsedTime time.Duration, ) Client { maskedUrl := utils.MaskCredentialsInUrl(url) logger.Warnf("connecting to redis on %s with pool size %d", maskedUrl, poolSize) @@ -191,47 +193,88 @@ func NewClientImpl(scope stats.Scope, useTls bool, auth, redisSocketType, redisT return poolConfig.New(ctx, network, addr) } - var client redisClient - var err error - ctx := context.Background() - - switch strings.ToLower(redisType) { - case "single": - logger.Warnf("Creating single with urls %v", url) - client, err = poolFunc(ctx, redisSocketType, url) - case "cluster": - urls := strings.Split(url, ",") - logger.Warnf("Creating cluster with urls %v", urls) - clusterConfig := radix.ClusterConfig{ - PoolConfig: poolConfig, - } - client, err = clusterConfig.New(ctx, urls) - case "sentinel": + // Validate sentinel URL format early (before retry loop) since it's a configuration error. + if strings.ToLower(redisType) == "sentinel" { urls := strings.Split(url, ",") if len(urls) < 2 { panic(RedisError("Expected master name and a list of urls for the sentinels, in the format: ,,...,")) } + } - // Create sentinel dialer (may use different auth from Redis master/replica) - // sentinelAuth is for Sentinel nodes, auth is for Redis master/replica - sentinelDialer := createDialer(timeout, useTls, tlsConfig, sentinelAuth, fmt.Sprintf("sentinel(%s)", maskedUrl)) + b := &backoff.Backoff{ + Min: startupInitialInterval, + Max: startupMaxInterval, + Factor: 2, + Jitter: true, + } + + startTime := time.Now() - sentinelConfig := radix.SentinelConfig{ - PoolConfig: poolConfig, - SentinelDialer: sentinelDialer, + retryOrDie := func(lastErr error) { + elapsed := time.Since(startTime) + if startupMaxElapsedTime > 0 && elapsed >= startupMaxElapsedTime { + panic(RedisError(fmt.Sprintf("timed out waiting for Redis connection to %s after %s: %v", maskedUrl, elapsed.Round(time.Millisecond), lastErr))) + } + d := b.Duration() + logger.Warnf("Retrying Redis connection to %s in %s (elapsed: %s): %v", maskedUrl, d, elapsed.Round(time.Millisecond), lastErr) + select { + case <-time.After(d): + case <-ctx.Done(): + panic(RedisError(fmt.Sprintf("context cancelled while waiting for Redis connection to %s: %v", maskedUrl, ctx.Err()))) } - client, err = sentinelConfig.New(ctx, urls[0], urls[1:]) - default: - panic(RedisError("Unrecognized redis type " + redisType)) } - checkError(err) + var client redisClient + for { + var err error + switch strings.ToLower(redisType) { + case "single": + logger.Warnf("Creating single with urls %v", url) + client, err = poolFunc(ctx, redisSocketType, url) + case "cluster": + urls := strings.Split(url, ",") + logger.Warnf("Creating cluster with urls %v", urls) + clusterConfig := radix.ClusterConfig{ + PoolConfig: poolConfig, + } + client, err = clusterConfig.New(ctx, urls) + case "sentinel": + urls := strings.Split(url, ",") + sentinelDialer := createDialer(timeout, useTls, tlsConfig, sentinelAuth, fmt.Sprintf("sentinel(%s)", maskedUrl)) + sentinelConfig := radix.SentinelConfig{ + PoolConfig: poolConfig, + SentinelDialer: sentinelDialer, + } + client, err = sentinelConfig.New(ctx, urls[0], urls[1:]) + default: + panic(RedisError("Unrecognized redis type " + redisType)) + } + + if err != nil { + retryOrDie(err) + continue + } - // Check if connection is good - var pingResponse string - checkError(client.Do(ctx, radix.Cmd(&pingResponse, "PING"))) - if pingResponse != "PONG" { - checkError(fmt.Errorf("connecting redis error: %s", pingResponse)) + var pingResponse string + if pingErr := client.Do(ctx, radix.Cmd(&pingResponse, "PING")); pingErr != nil { + _ = client.Close() + retryOrDie(pingErr) + continue + } + if pingResponse != "PONG" { + _ = client.Close() + retryOrDie(fmt.Errorf("unexpected PING response: %q", pingResponse)) + continue + } + + // Successfully connected. + break + } + + if srv != nil { + if err := srv.HealthChecker().Ok(server.RedisHealthComponentName); err != nil { + logger.Errorf("Unable to update health status after Redis connection: %s", err) + } } return &clientImpl{ diff --git a/src/server/health.go b/src/server/health.go index 244a760e7..16ade5eff 100644 --- a/src/server/health.go +++ b/src/server/health.go @@ -49,7 +49,8 @@ func NewHealthChecker(grpcHealthServer *health.Server, name string, healthyWithA ret.healthMap = make(map[string]bool) // Store health states of components into map - ret.healthMap[RedisHealthComponentName] = true + // Redis starts unhealthy; it is marked healthy once the connection is confirmed at startup. + ret.healthMap[RedisHealthComponentName] = false if healthyWithAtLeastOneConfigLoad { // config starts in failed state since we need at least one config loaded to be healthy ret.healthMap[ConfigHealthComponentName] = false @@ -111,7 +112,7 @@ func (hc *HealthChecker) Ok(componentName string) error { // Set component to be healthy hc.healthMap[componentName] = true allComponentsHealthy := areAllComponentsHealthy(hc.healthMap) - + logger.Debugf("Health status of components: %v, all healthy: %t", hc.healthMap, allComponentsHealthy) if allComponentsHealthy { atomic.StoreUint32(&hc.ok, 1) hc.grpc.SetServingStatus(hc.name, healthpb.HealthCheckResponse_SERVING) diff --git a/src/server/server.go b/src/server/server.go index 7202fa2f5..d315c2aef 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/http" pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" @@ -17,7 +18,7 @@ type Server interface { * all endpoints have been registered through 'AddHttpEndpoint' * and 'GrpcServer'. */ - Start() + Start(ctx context.Context) /** * Returns the root of the stats tree for the server diff --git a/src/server/server_impl.go b/src/server/server_impl.go index e9402da0a..a82f2821f 100644 --- a/src/server/server_impl.go +++ b/src/server/server_impl.go @@ -8,12 +8,9 @@ import ( "net" "net/http" "net/http/pprof" - "os" - "os/signal" "sort" "strconv" "sync" - "syscall" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" @@ -166,7 +163,7 @@ func (server *server) GrpcServer() *grpc.Server { return server.grpcServer } -func (server *server) Start() { +func (server *server) Start(ctx context.Context) { go func() { logger.Warnf("Listening for debug on '%s'", server.debugAddress) var err error @@ -184,7 +181,7 @@ func (server *server) Start() { go server.startGrpc() - server.handleGracefulShutdown() + server.handleGracefulShutdown(ctx) logger.Warnf("Listening for HTTP on '%s'", server.httpAddress) list, err := reuseport.Listen("tcp", server.httpAddress) @@ -365,16 +362,11 @@ func (server *server) Stop() { server.provider.Stop() } -func (server *server) handleGracefulShutdown() { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - +func (server *server) handleGracefulShutdown(ctx context.Context) { go func() { - sig := <-sigs - - logger.Infof("Ratelimit server received %v, shutting down gracefully", sig) + <-ctx.Done() + logger.Infof("Context cancelled, stopping server") server.Stop() - os.Exit(0) }() } diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index a3889a58d..d5de0c365 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -5,8 +5,11 @@ import ( "io" "math/rand" "net/http" + "os" + "os/signal" "strings" "sync" + "syscall" "time" "github.com/coocood/freecache" @@ -34,6 +37,8 @@ type Runner struct { srv server.Server mu sync.Mutex ratelimitCloser io.Closer + cancel context.CancelFunc + done chan struct{} } func NewRunner(s settings.Settings) Runner { @@ -82,6 +87,7 @@ func NewRunner(s settings.Settings) Runner { return Runner{ statsManager: stats.NewStatManager(store, s), settings: s, + done: make(chan struct{}), } } @@ -89,10 +95,11 @@ func (runner *Runner) GetStatsStore() gostats.Store { return runner.statsManager.GetStatsStore() } -func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { +func createLimiter(ctx context.Context, srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { switch s.BackendType { case "redis", "": return redis.NewRateLimiterCacheImplFromSettings( + ctx, s, localCache, srv, @@ -116,11 +123,33 @@ func createLimiter(srv server.Server, s settings.Settings, localCache *freecache } func (runner *Runner) Run() { + defer close(runner.done) + + ctx, cancel := context.WithCancel(context.Background()) + runner.mu.Lock() + runner.cancel = cancel + runner.mu.Unlock() + defer cancel() + + // Set up signal handling for graceful shutdown + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + go func() { + select { + case sig := <-sigs: + logger.Infof("Received signal %v, initiating shutdown", sig) + cancel() + case <-ctx.Done(): + } + }() + s := runner.settings if s.TracingEnabled { tp := trace.InitProductionTraceProvider(s.TracingExporterProtocol, s.TracingServiceName, s.TracingServiceNamespace, s.TracingServiceInstanceId, s.TracingSamplingRate) defer func() { - if err := tp.Shutdown(context.Background()); err != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := tp.Shutdown(shutdownCtx); err != nil { logger.Printf("Error shutting down tracer provider: %v", err) } }() @@ -156,8 +185,13 @@ func (runner *Runner) Run() { runner.srv = srv runner.mu.Unlock() - limiter, limiterCloser := createLimiter(srv, s, localCache, runner.statsManager) + limiter, limiterCloser := createLimiter(ctx, srv, s, localCache, runner.statsManager) runner.ratelimitCloser = limiterCloser + defer func() { + if err := limiterCloser.Close(); err != nil { + logger.Errorf("Error closing rate limiter resources: %v", err) + } + }() service := ratelimit.NewService( limiter, @@ -186,18 +220,15 @@ func (runner *Runner) Run() { // v2 proto is no longer supported pb.RegisterRateLimitServiceServer(srv.GrpcServer(), service) - srv.Start() + srv.Start(ctx) } func (runner *Runner) Stop() { runner.mu.Lock() - srv := runner.srv + cancel := runner.cancel runner.mu.Unlock() - if srv != nil { - srv.Stop() - } - - if runner.ratelimitCloser != nil { - _ = runner.ratelimitCloser.Close() + if cancel != nil { + cancel() } + <-runner.done } diff --git a/src/settings/settings.go b/src/settings/settings.go index 6a1be618f..83d657764 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -191,6 +191,14 @@ type Settings struct { // See RedisPoolOnEmptyBehavior for possible values and details. RedisPerSecondPoolOnEmptyBehavior string `envconfig:"REDIS_PERSECOND_POOL_ON_EMPTY_BEHAVIOR" default:"WAIT"` + // RedisStartupInitialInterval is the initial backoff interval when retrying Redis connection at startup. + RedisStartupInitialInterval time.Duration `envconfig:"REDIS_STARTUP_INITIAL_INTERVAL" default:"1s"` + // RedisStartupMaxInterval is the maximum backoff interval between Redis connection retries at startup. + RedisStartupMaxInterval time.Duration `envconfig:"REDIS_STARTUP_MAX_INTERVAL" default:"30s"` + // RedisStartupMaxElapsedTime is the total time to keep retrying the Redis connection at startup. + // 0 means retry indefinitely until the connection succeeds. + RedisStartupMaxElapsedTime time.Duration `envconfig:"REDIS_STARTUP_MAX_ELAPSED_TIME" default:"0"` + // Memcache settings MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""` // MemcacheMaxIdleConns sets the maximum number of idle TCP connections per memcached node. diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go index 2c153ba97..b055dc4e6 100644 --- a/test/redis/bench_test.go +++ b/test/redis/bench_test.go @@ -44,7 +44,7 @@ func BenchmarkParallelDoLimit(b *testing.B) { return func(b *testing.B) { statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - client := redis.NewClientImpl(statsStore, false, "", "tcp", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "") + client := redis.NewClientImpl(context.Background(), statsStore, false, "", "tcp", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "", time.Second, 30*time.Second, 0) defer client.Close() cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "", sm, true) diff --git a/test/redis/driver_impl_test.go b/test/redis/driver_impl_test.go index 8d1b30132..323c51e37 100644 --- a/test/redis/driver_impl_test.go +++ b/test/redis/driver_impl_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "fmt" "strings" "testing" @@ -38,8 +39,9 @@ func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit redisAuth := "123" statsStore := stats.NewStore(stats.NewNullSink(), false) + // Use a short maxElapsedTime so failing connection tests don't hang in retry loops. mkRedisClient := func(auth, addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, auth, "tcp", "single", addr, 1, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "") + return redis.NewClientImpl(context.Background(), statsStore, false, auth, "tcp", "single", addr, 1, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "", time.Second, 30*time.Second, 100*time.Millisecond) } t.Run("connection refused", func(t *testing.T) { @@ -66,9 +68,8 @@ func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit redisSrv.RequireAuth(redisAuth) - assert.PanicsWithError(t, "response returned from Conn: NOAUTH Authentication required.", func() { - mkRedisClient("", redisSrv.Addr()) - }) + panicErr := expectPanicError(t, func() { mkRedisClient("", redisSrv.Addr()) }) + assert.Contains(t, panicErr.Error(), "NOAUTH") }) t.Run("auth pass", func(t *testing.T) { @@ -103,9 +104,8 @@ func testNewClientImpl(t *testing.T, pipelineWindow time.Duration, pipelineLimit redisSrv.RequireUserAuth(user, pass) redisAuth := fmt.Sprintf("%s:invalid-password", user) - assert.PanicsWithError(t, "response returned from Conn: WRONGPASS invalid username-password pair", func() { - mkRedisClient(redisAuth, redisSrv.Addr()) - }) + panicErr := expectPanicError(t, func() { mkRedisClient(redisAuth, redisSrv.Addr()) }) + assert.Contains(t, panicErr.Error(), "WRONGPASS") }) } } @@ -119,7 +119,7 @@ func TestDoCmd(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "tcp", "single", addr, 1, 0, 0, nil, false, nil, 10*time.Second, "", "") + return redis.NewClientImpl(context.Background(), statsStore, false, "", "tcp", "single", addr, 1, 0, 0, nil, false, nil, 10*time.Second, "", "", time.Second, 30*time.Second, 0) } t.Run("SETGET ok", func(t *testing.T) { @@ -164,7 +164,7 @@ func testPipeDo(t *testing.T, pipelineWindow time.Duration, pipelineLimit int) f statsStore := stats.NewStore(stats.NewNullSink(), false) mkRedisClient := func(addr string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "tcp", "single", addr, 1, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "") + return redis.NewClientImpl(context.Background(), statsStore, false, "", "tcp", "single", addr, 1, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "", time.Second, 30*time.Second, 0) } t.Run("SETGET ok", func(t *testing.T) { @@ -232,7 +232,7 @@ func TestPoolOnEmptyBehavior(t *testing.T) { // Helper to create client with specific on-empty behavior mkRedisClientWithBehavior := func(addr, behavior string) redis.Client { - return redis.NewClientImpl(statsStore, false, "", "tcp", "single", addr, 1, 0, 0, nil, false, nil, 10*time.Second, behavior, "") + return redis.NewClientImpl(context.Background(), statsStore, false, "", "tcp", "single", addr, 1, 0, 0, nil, false, nil, 10*time.Second, behavior, "", time.Second, 30*time.Second, 0) } t.Run("default behavior (empty string)", func(t *testing.T) { @@ -356,7 +356,8 @@ func TestNewClientImplSentinel(t *testing.T) { mkSentinelClient := func(auth, sentinelAuth, url string, useTls bool, timeout time.Duration) redis.Client { // Pass nil for tlsConfig - we can't test TLS without a real TLS server, // but we can verify the code path is executed (logs will show TLS is enabled) - return redis.NewClientImpl(statsStore, useTls, auth, "tcp", "sentinel", url, 1, 0, 0, nil, false, nil, timeout, "", sentinelAuth) + // Use a short maxElapsedTime so failing connection tests don't hang in retry loops. + return redis.NewClientImpl(context.Background(), statsStore, useTls, auth, "tcp", "sentinel", url, 1, 0, 0, nil, false, nil, timeout, "", sentinelAuth, time.Second, 30*time.Second, 100*time.Millisecond) } t.Run("invalid url format - missing sentinel addresses", func(t *testing.T) { diff --git a/test/server/health_test.go b/test/server/health_test.go index a2238a5c5..40e7b5187 100644 --- a/test/server/health_test.go +++ b/test/server/health_test.go @@ -18,117 +18,81 @@ import ( func TestHealthCheck(t *testing.T) { defer signal.Reset(syscall.SIGTERM) - recorder := httptest.NewRecorder() - hc := server.NewHealthChecker(health.NewServer(), "ratelimit", false) - r, _ := http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 200 { - t.Errorf("expected code 200 actual %d", recorder.Code) + checkHTTP := func(wantCode int) { + recorder := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) + hc.ServeHTTP(recorder, r) + if recorder.Code != wantCode { + t.Errorf("expected code %d actual %d", wantCode, recorder.Code) + } } - if recorder.Body.String() != "OK" { - t.Errorf("expected body 'OK', got '%s'", recorder.Body.String()) - } + // Redis starts unhealthy until the connection is confirmed. + checkHTTP(500) - err := hc.Fail(server.RedisHealthComponentName) + err := hc.Ok(server.RedisHealthComponentName) if err != nil { t.Errorf("Expected no errors for updating redis health status") } + checkHTTP(200) - recorder = httptest.NewRecorder() - - r, _ = http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 500 { - t.Errorf("expected code 500 actual %d", recorder.Code) + err = hc.Fail(server.RedisHealthComponentName) + if err != nil { + t.Errorf("Expected no errors for updating redis health status") } + checkHTTP(500) err = hc.Ok(server.RedisHealthComponentName) if err != nil { t.Errorf("Expected no errors for updating redis health status") } - - recorder = httptest.NewRecorder() - - r, _ = http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 200 { - t.Errorf("expected code 200 actual %d", recorder.Code) - } - - if recorder.Body.String() != "OK" { - t.Errorf("expected body 'OK', got '%s'", recorder.Body.String()) - } + checkHTTP(200) } func TestHealthyWithAtLeastOneConfigLoaded(t *testing.T) { defer signal.Reset(syscall.SIGTERM) - recorder := httptest.NewRecorder() - hc := server.NewHealthChecker(health.NewServer(), "ratelimit", true) - r, _ := http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 500 { - t.Errorf("expected code 500 actual %d", recorder.Code) + checkHTTP := func(wantCode int) { + recorder := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) + hc.ServeHTTP(recorder, r) + if recorder.Code != wantCode { + t.Errorf("expected code %d actual %d", wantCode, recorder.Code) + } } + // Both Redis and config start unhealthy. + checkHTTP(500) + err := hc.Ok(server.ConfigHealthComponentName) if err != nil { t.Errorf("Expected no errors for updating config health status") } + // Config is ready but Redis still unhealthy. + checkHTTP(500) - recorder = httptest.NewRecorder() - - r, _ = http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 200 { - t.Errorf("expected code 200 actual %d", recorder.Code) - } - - if recorder.Body.String() != "OK" { - t.Errorf("expected body 'OK', got '%s'", recorder.Body.String()) + err = hc.Ok(server.RedisHealthComponentName) + if err != nil { + t.Errorf("Expected no errors for updating redis health status") } + // Both ready now. + checkHTTP(200) err = hc.Fail(server.RedisHealthComponentName) if err != nil { t.Errorf("Expected no errors for updating redis health status") } - - recorder = httptest.NewRecorder() - - r, _ = http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 500 { - t.Errorf("expected code 500 actual %d", recorder.Code) - } + checkHTTP(500) err = hc.Ok(server.RedisHealthComponentName) if err != nil { t.Errorf("Expected no errors for updating redis health status") } - - recorder = httptest.NewRecorder() - - r, _ = http.NewRequest("GET", "http://1.2.3.4/healthcheck", nil) - hc.ServeHTTP(recorder, r) - - if recorder.Code != 200 { - t.Errorf("expected code 200 actual %d", recorder.Code) - } - - if recorder.Body.String() != "OK" { - t.Errorf("expected body 'OK', got '%s'", recorder.Body.String()) - } + checkHTTP(200) } func TestGrpcHealthCheck(t *testing.T) { @@ -142,9 +106,10 @@ func TestGrpcHealthCheck(t *testing.T) { Service: "ratelimit", } + // Redis starts unhealthy until the connection is confirmed. res, _ := grpcHealthServer.Check(context.Background(), req) - if healthpb.HealthCheckResponse_SERVING != res.Status { - t.Errorf("expected status SERVING actual %v", res.Status) + if healthpb.HealthCheckResponse_NOT_SERVING != res.Status { + t.Errorf("expected status NOT_SERVING actual %v", res.Status) } err := hc.Ok(server.RedisHealthComponentName) diff --git a/test/service/ratelimit_test.go b/test/service/ratelimit_test.go index c77ba6b87..dc4ab8a7f 100644 --- a/test/service/ratelimit_test.go +++ b/test/service/ratelimit_test.go @@ -97,6 +97,8 @@ func commonSetup(t *testing.T) rateLimitServiceTestSuite { ret.statStore = gostats.NewStore(gostats.NewNullSink(), false) ret.statsManager = mock_stats.NewMockStatManager(ret.statStore) ret.health = server.NewHealthChecker(health.NewServer(), "ratelimit", false) + // Tests use a mocked cache, so simulate a successful Redis connection. + _ = ret.health.Ok(server.RedisHealthComponentName) return ret } @@ -597,6 +599,8 @@ func TestServiceHealthStatus(test *testing.T) { healthyWithAtLeastOneConfigLoaded := false grpcHealthServer := health.NewServer() hc := server.NewHealthChecker(grpcHealthServer, "ratelimit", healthyWithAtLeastOneConfigLoaded) + // Tests use a mocked cache, so simulate a successful Redis connection. + _ = hc.Ok(server.RedisHealthComponentName) healthpb.RegisterHealthServer(grpc.NewServer(), grpcHealthServer) // Set up the service @@ -623,6 +627,8 @@ func TestServiceHealthStatusAtLeastOneConfigLoaded(test *testing.T) { healthyWithAtLeastOneConfigLoaded := true grpcHealthServer := health.NewServer() hc := server.NewHealthChecker(grpcHealthServer, "ratelimit", healthyWithAtLeastOneConfigLoaded) + // Tests use a mocked cache, so simulate a successful Redis connection. + _ = hc.Ok(server.RedisHealthComponentName) healthpb.RegisterHealthServer(grpc.NewServer(), grpcHealthServer) // Set up the service