Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/redis/cache_impl.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis

import (
"context"
"io"
"math/rand"

Expand All @@ -13,19 +14,19 @@ import (
"github.com/envoyproxy/ratelimit/src/utils"
)

func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) {
func NewRateLimiterCacheImplFromSettings(ctx context.Context, s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) {
closer := &utils.MultiCloser{}
var perSecondPool Client
if s.RedisPerSecond {
perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType,
perSecondPool = NewClientImpl(ctx, srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType,
s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv, s.RedisPerSecondTimeout,
s.RedisPerSecondPoolOnEmptyBehavior, s.RedisPerSecondSentinelAuth)
s.RedisPerSecondPoolOnEmptyBehavior, s.RedisPerSecondSentinelAuth, s.RedisStartupInitialInterval, s.RedisStartupMaxInterval, s.RedisStartupMaxElapsedTime)
closer.Closers = append(closer.Closers, perSecondPool)
}

otherPool := NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize,
otherPool := NewClientImpl(ctx, srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize,
s.RedisPipelineWindow, s.RedisPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv, s.RedisTimeout,
s.RedisPoolOnEmptyBehavior, s.RedisSentinelAuth)
s.RedisPoolOnEmptyBehavior, s.RedisSentinelAuth, s.RedisStartupInitialInterval, s.RedisStartupMaxInterval, s.RedisStartupMaxElapsedTime)
closer.Closers = append(closer.Closers, otherPool)

return NewFixedRateLimitCacheImpl(
Expand Down
107 changes: 75 additions & 32 deletions src/redis/driver_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"time"

"github.com/jpillora/backoff"
stats "github.com/lyft/gostats"
"github.com/mediocregopher/radix/v4"
"github.com/mediocregopher/radix/v4/trace"
Expand Down Expand Up @@ -119,9 +120,10 @@ func createDialer(timeout time.Duration, useTls bool, tlsConfig *tls.Config, aut
return dialer
}

func NewClientImpl(scope stats.Scope, useTls bool, auth, redisSocketType, redisType, url string, poolSize int,
func NewClientImpl(ctx context.Context, scope stats.Scope, useTls bool, auth, redisSocketType, redisType, url string, poolSize int,
pipelineWindow time.Duration, pipelineLimit int, tlsConfig *tls.Config, healthCheckActiveConnection bool, srv server.Server,
timeout time.Duration, poolOnEmptyBehavior string, sentinelAuth string,
startupInitialInterval, startupMaxInterval, startupMaxElapsedTime time.Duration,
) Client {
maskedUrl := utils.MaskCredentialsInUrl(url)
logger.Warnf("connecting to redis on %s with pool size %d", maskedUrl, poolSize)
Expand Down Expand Up @@ -191,47 +193,88 @@ func NewClientImpl(scope stats.Scope, useTls bool, auth, redisSocketType, redisT
return poolConfig.New(ctx, network, addr)
}

var client redisClient
var err error
ctx := context.Background()

switch strings.ToLower(redisType) {
case "single":
logger.Warnf("Creating single with urls %v", url)
client, err = poolFunc(ctx, redisSocketType, url)
case "cluster":
urls := strings.Split(url, ",")
logger.Warnf("Creating cluster with urls %v", urls)
clusterConfig := radix.ClusterConfig{
PoolConfig: poolConfig,
}
client, err = clusterConfig.New(ctx, urls)
case "sentinel":
// Validate sentinel URL format early (before retry loop) since it's a configuration error.
if strings.ToLower(redisType) == "sentinel" {
urls := strings.Split(url, ",")
if len(urls) < 2 {
panic(RedisError("Expected master name and a list of urls for the sentinels, in the format: <redis master name>,<sentinel1>,...,<sentineln>"))
}
}

// Create sentinel dialer (may use different auth from Redis master/replica)
// sentinelAuth is for Sentinel nodes, auth is for Redis master/replica
sentinelDialer := createDialer(timeout, useTls, tlsConfig, sentinelAuth, fmt.Sprintf("sentinel(%s)", maskedUrl))
b := &backoff.Backoff{
Min: startupInitialInterval,
Max: startupMaxInterval,
Factor: 2,
Jitter: true,
}

startTime := time.Now()

sentinelConfig := radix.SentinelConfig{
PoolConfig: poolConfig,
SentinelDialer: sentinelDialer,
retryOrDie := func(lastErr error) {
elapsed := time.Since(startTime)
if startupMaxElapsedTime > 0 && elapsed >= startupMaxElapsedTime {
panic(RedisError(fmt.Sprintf("timed out waiting for Redis connection to %s after %s: %v", maskedUrl, elapsed.Round(time.Millisecond), lastErr)))
}
d := b.Duration()
logger.Warnf("Retrying Redis connection to %s in %s (elapsed: %s): %v", maskedUrl, d, elapsed.Round(time.Millisecond), lastErr)
select {
case <-time.After(d):
case <-ctx.Done():
panic(RedisError(fmt.Sprintf("context cancelled while waiting for Redis connection to %s: %v", maskedUrl, ctx.Err())))
}
client, err = sentinelConfig.New(ctx, urls[0], urls[1:])
default:
panic(RedisError("Unrecognized redis type " + redisType))
}

checkError(err)
var client redisClient
for {
var err error
switch strings.ToLower(redisType) {
case "single":
logger.Warnf("Creating single with urls %v", url)
client, err = poolFunc(ctx, redisSocketType, url)
case "cluster":
urls := strings.Split(url, ",")
logger.Warnf("Creating cluster with urls %v", urls)
clusterConfig := radix.ClusterConfig{
PoolConfig: poolConfig,
}
client, err = clusterConfig.New(ctx, urls)
case "sentinel":
urls := strings.Split(url, ",")
sentinelDialer := createDialer(timeout, useTls, tlsConfig, sentinelAuth, fmt.Sprintf("sentinel(%s)", maskedUrl))
sentinelConfig := radix.SentinelConfig{
PoolConfig: poolConfig,
SentinelDialer: sentinelDialer,
}
client, err = sentinelConfig.New(ctx, urls[0], urls[1:])
default:
panic(RedisError("Unrecognized redis type " + redisType))
}

if err != nil {
retryOrDie(err)
continue
}

// Check if connection is good
var pingResponse string
checkError(client.Do(ctx, radix.Cmd(&pingResponse, "PING")))
if pingResponse != "PONG" {
checkError(fmt.Errorf("connecting redis error: %s", pingResponse))
var pingResponse string
if pingErr := client.Do(ctx, radix.Cmd(&pingResponse, "PING")); pingErr != nil {
_ = client.Close()
retryOrDie(pingErr)
continue
}
if pingResponse != "PONG" {
_ = client.Close()
retryOrDie(fmt.Errorf("unexpected PING response: %q", pingResponse))
continue
}

// Successfully connected.
break
}

if srv != nil {
if err := srv.HealthChecker().Ok(server.RedisHealthComponentName); err != nil {
logger.Errorf("Unable to update health status after Redis connection: %s", err)
}
}

return &clientImpl{
Expand Down
5 changes: 3 additions & 2 deletions src/server/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func NewHealthChecker(grpcHealthServer *health.Server, name string, healthyWithA

ret.healthMap = make(map[string]bool)
// Store health states of components into map
ret.healthMap[RedisHealthComponentName] = true
// Redis starts unhealthy; it is marked healthy once the connection is confirmed at startup.
ret.healthMap[RedisHealthComponentName] = false
if healthyWithAtLeastOneConfigLoad {
// config starts in failed state since we need at least one config loaded to be healthy
ret.healthMap[ConfigHealthComponentName] = false
Expand Down Expand Up @@ -111,7 +112,7 @@ func (hc *HealthChecker) Ok(componentName string) error {
// Set component to be healthy
hc.healthMap[componentName] = true
allComponentsHealthy := areAllComponentsHealthy(hc.healthMap)

logger.Debugf("Health status of components: %v, all healthy: %t", hc.healthMap, allComponentsHealthy)
if allComponentsHealthy {
atomic.StoreUint32(&hc.ok, 1)
hc.grpc.SetServingStatus(hc.name, healthpb.HealthCheckResponse_SERVING)
Expand Down
3 changes: 2 additions & 1 deletion src/server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"net/http"

pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3"
Expand All @@ -17,7 +18,7 @@ type Server interface {
* all endpoints have been registered through 'AddHttpEndpoint'
* and 'GrpcServer'.
*/
Start()
Start(ctx context.Context)

/**
* Returns the root of the stats tree for the server
Expand Down
18 changes: 5 additions & 13 deletions src/server/server_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ import (
"net"
"net/http"
"net/http/pprof"
"os"
"os/signal"
"sort"
"strconv"
"sync"
"syscall"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
Expand Down Expand Up @@ -166,7 +163,7 @@ func (server *server) GrpcServer() *grpc.Server {
return server.grpcServer
}

func (server *server) Start() {
func (server *server) Start(ctx context.Context) {
go func() {
logger.Warnf("Listening for debug on '%s'", server.debugAddress)
var err error
Expand All @@ -184,7 +181,7 @@ func (server *server) Start() {

go server.startGrpc()

server.handleGracefulShutdown()
server.handleGracefulShutdown(ctx)

logger.Warnf("Listening for HTTP on '%s'", server.httpAddress)
list, err := reuseport.Listen("tcp", server.httpAddress)
Expand Down Expand Up @@ -365,16 +362,11 @@ func (server *server) Stop() {
server.provider.Stop()
}

func (server *server) handleGracefulShutdown() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)

func (server *server) handleGracefulShutdown(ctx context.Context) {
go func() {
sig := <-sigs

logger.Infof("Ratelimit server received %v, shutting down gracefully", sig)
<-ctx.Done()
logger.Infof("Context cancelled, stopping server")
server.Stop()
os.Exit(0)
}()
}

Expand Down
53 changes: 42 additions & 11 deletions src/service_cmd/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"io"
"math/rand"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"

"github.com/coocood/freecache"
Expand Down Expand Up @@ -34,6 +37,8 @@ type Runner struct {
srv server.Server
mu sync.Mutex
ratelimitCloser io.Closer
cancel context.CancelFunc
done chan struct{}
}

func NewRunner(s settings.Settings) Runner {
Expand Down Expand Up @@ -82,17 +87,19 @@ func NewRunner(s settings.Settings) Runner {
return Runner{
statsManager: stats.NewStatManager(store, s),
settings: s,
done: make(chan struct{}),
}
}

func (runner *Runner) GetStatsStore() gostats.Store {
return runner.statsManager.GetStatsStore()
}

func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) {
func createLimiter(ctx context.Context, srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) {
switch s.BackendType {
case "redis", "":
return redis.NewRateLimiterCacheImplFromSettings(
ctx,
s,
localCache,
srv,
Expand All @@ -116,11 +123,33 @@ func createLimiter(srv server.Server, s settings.Settings, localCache *freecache
}

func (runner *Runner) Run() {
defer close(runner.done)

ctx, cancel := context.WithCancel(context.Background())
runner.mu.Lock()
runner.cancel = cancel
runner.mu.Unlock()
defer cancel()

// Set up signal handling for graceful shutdown
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
go func() {
select {
case sig := <-sigs:
logger.Infof("Received signal %v, initiating shutdown", sig)
cancel()
case <-ctx.Done():
}
}()

s := runner.settings
if s.TracingEnabled {
tp := trace.InitProductionTraceProvider(s.TracingExporterProtocol, s.TracingServiceName, s.TracingServiceNamespace, s.TracingServiceInstanceId, s.TracingSamplingRate)
defer func() {
if err := tp.Shutdown(context.Background()); err != nil {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := tp.Shutdown(shutdownCtx); err != nil {
logger.Printf("Error shutting down tracer provider: %v", err)
}
}()
Expand Down Expand Up @@ -156,8 +185,13 @@ func (runner *Runner) Run() {
runner.srv = srv
runner.mu.Unlock()

limiter, limiterCloser := createLimiter(srv, s, localCache, runner.statsManager)
limiter, limiterCloser := createLimiter(ctx, srv, s, localCache, runner.statsManager)
runner.ratelimitCloser = limiterCloser
defer func() {
if err := limiterCloser.Close(); err != nil {
logger.Errorf("Error closing rate limiter resources: %v", err)
}
}()

service := ratelimit.NewService(
limiter,
Expand Down Expand Up @@ -186,18 +220,15 @@ func (runner *Runner) Run() {
// v2 proto is no longer supported
pb.RegisterRateLimitServiceServer(srv.GrpcServer(), service)

srv.Start()
srv.Start(ctx)
}

func (runner *Runner) Stop() {
runner.mu.Lock()
srv := runner.srv
cancel := runner.cancel
runner.mu.Unlock()
if srv != nil {
srv.Stop()
}

if runner.ratelimitCloser != nil {
_ = runner.ratelimitCloser.Close()
if cancel != nil {
cancel()
}
<-runner.done
}
8 changes: 8 additions & 0 deletions src/settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ type Settings struct {
// See RedisPoolOnEmptyBehavior for possible values and details.
RedisPerSecondPoolOnEmptyBehavior string `envconfig:"REDIS_PERSECOND_POOL_ON_EMPTY_BEHAVIOR" default:"WAIT"`

// RedisStartupInitialInterval is the initial backoff interval when retrying Redis connection at startup.
RedisStartupInitialInterval time.Duration `envconfig:"REDIS_STARTUP_INITIAL_INTERVAL" default:"1s"`
// RedisStartupMaxInterval is the maximum backoff interval between Redis connection retries at startup.
RedisStartupMaxInterval time.Duration `envconfig:"REDIS_STARTUP_MAX_INTERVAL" default:"30s"`
// RedisStartupMaxElapsedTime is the total time to keep retrying the Redis connection at startup.
// 0 means retry indefinitely until the connection succeeds.
RedisStartupMaxElapsedTime time.Duration `envconfig:"REDIS_STARTUP_MAX_ELAPSED_TIME" default:"0"`

// Memcache settings
MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""`
// MemcacheMaxIdleConns sets the maximum number of idle TCP connections per memcached node.
Expand Down
2 changes: 1 addition & 1 deletion test/redis/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func BenchmarkParallelDoLimit(b *testing.B) {
return func(b *testing.B) {
statsStore := gostats.NewStore(gostats.NewNullSink(), false)
sm := stats.NewMockStatManager(statsStore)
client := redis.NewClientImpl(statsStore, false, "", "tcp", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "")
client := redis.NewClientImpl(context.Background(), statsStore, false, "", "tcp", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "", time.Second, 30*time.Second, 0)
defer client.Close()

cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "", sm, true)
Expand Down
Loading
Loading