diff --git a/client.go b/client.go index 13fb1653..6d810d89 100644 --- a/client.go +++ b/client.go @@ -30,10 +30,8 @@ import ( "github.com/riverqueue/river/rivershared/riverpilot" "github.com/riverqueue/river/rivershared/riversharedmaintenance" "github.com/riverqueue/river/rivershared/startstop" - "github.com/riverqueue/river/rivershared/testsignal" "github.com/riverqueue/river/rivershared/util/dbutil" "github.com/riverqueue/river/rivershared/util/maputil" - "github.com/riverqueue/river/rivershared/util/serviceutil" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivershared/util/testutil" "github.com/riverqueue/river/rivershared/util/valutil" @@ -605,18 +603,10 @@ type Client[TTx any] struct { notifier *notifier.Notifier // may be nil in poll-only mode periodicJobs *PeriodicJobBundle pilot riverpilot.Pilot - producersByQueueName map[string]*producer - queueMaintainer *maintenance.QueueMaintainer - - // queueMaintainerEpoch is incremented each time leadership is gained, - // giving each tryStartQueueMaintainer goroutine a term number. - // queueMaintainerMu serializes epoch checks with Stop calls so that a - // stale goroutine from an older term cannot tear down a maintainer - // started by a newer term. - queueMaintainerEpoch int64 - queueMaintainerMu sync.Mutex - - queues *QueueBundle + producersByQueueName map[string]*producer + queueMaintainer *maintenance.QueueMaintainer + queueMaintainerLeader *maintenance.QueueMaintainerLeader + queues *QueueBundle services []startstop.Service stopped <-chan struct{} subscriptionManager *subscriptionManager @@ -629,23 +619,16 @@ type Client[TTx any] struct { // Test-only signals. 
type clientTestSignals struct { - electedLeader testsignal.TestSignal[struct{}] // notifies when elected leader - queueMaintainerStartError testsignal.TestSignal[error] // notifies on each failed queue maintainer start attempt - queueMaintainerStartRetriesExhausted testsignal.TestSignal[struct{}] // notifies when leader resignation is requested after all queue maintainer start retries have been exhausted - - jobCleaner *maintenance.JobCleanerTestSignals - jobRescuer *maintenance.JobRescuerTestSignals - jobScheduler *maintenance.JobSchedulerTestSignals - periodicJobEnqueuer *maintenance.PeriodicJobEnqueuerTestSignals - queueCleaner *maintenance.QueueCleanerTestSignals - reindexer *maintenance.ReindexerTestSignals + jobCleaner *maintenance.JobCleanerTestSignals + jobRescuer *maintenance.JobRescuerTestSignals + jobScheduler *maintenance.JobSchedulerTestSignals + periodicJobEnqueuer *maintenance.PeriodicJobEnqueuerTestSignals + queueCleaner *maintenance.QueueCleanerTestSignals + queueMaintainerLeader *maintenance.QueueMaintainerLeaderTestSignals + reindexer *maintenance.ReindexerTestSignals } func (ts *clientTestSignals) Init(tb testutil.TestingTB) { - ts.electedLeader.Init(tb) - ts.queueMaintainerStartError.Init(tb) - ts.queueMaintainerStartRetriesExhausted.Init(tb) - if ts.jobCleaner != nil { ts.jobCleaner.Init(tb) } @@ -661,6 +644,9 @@ func (ts *clientTestSignals) Init(tb testutil.TestingTB) { if ts.queueCleaner != nil { ts.queueCleaner.Init(tb) } + if ts.queueMaintainerLeader != nil { + ts.queueMaintainerLeader.Init(tb) + } if ts.reindexer != nil { ts.reindexer.Init(tb) } @@ -867,9 +853,6 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.services = append(client.services, startstop.StartStopFunc(client.logStatsLoop)) - client.services = append(client.services, - startstop.StartStopFunc(client.handleLeadershipChangeLoop)) - if pluginPilot != nil { client.services = append(client.services, pluginPilot.PluginServices()...) 
} @@ -972,6 +955,15 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client if config.TestOnly { client.queueMaintainer.StaggerStartupDisable(true) } + + client.queueMaintainerLeader = maintenance.NewQueueMaintainerLeader(archetype, &maintenance.QueueMaintainerLeaderConfig{ + ClientID: config.ID, + Elector: client.elector, + QueueMaintainer: client.queueMaintainer, + RequestResignFunc: client.clientNotifyBundle.RequestResign, + }) + client.services = append(client.services, client.queueMaintainerLeader) + client.testSignals.queueMaintainerLeader = &client.queueMaintainerLeader.TestSignals } return client, nil @@ -1292,147 +1284,6 @@ func (c *Client[TTx]) logStatsLoop(ctx context.Context, shouldStart bool, starte return nil } -func (c *Client[TTx]) handleLeadershipChangeLoop(ctx context.Context, shouldStart bool, started, stopped func()) error { - if !shouldStart { - return nil - } - - go func() { - started() - defer stopped() // this defer should come first so it's last out - - sub := c.elector.Listen() - defer sub.Unlisten() - - // Cancel function for an in-progress tryStartQueueMaintainer. If - // leadership is lost while the start process is still retrying, used to - // abort it promptly instead of waiting for retries to finish. - var cancelQueueMaintainerStart context.CancelCauseFunc = func(_ error) {} - - for { - select { - case <-ctx.Done(): - cancelQueueMaintainerStart(context.Cause(ctx)) - return - - case notification := <-sub.C(): - c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": Election change received", - slog.String("client_id", c.config.ID), slog.Bool("is_leader", notification.IsLeader)) - - switch { - case notification.IsLeader: - // Starting the queue maintainer takes time, so send the - // test signal first. Tests waiting on it can receive it, - // cancel the queue maintainer start, and finish faster. 
- c.testSignals.electedLeader.Signal(struct{}{}) - - // Start the queue maintainer with retries and exponential - // backoff in a separate goroutine so the leadership change - // loop remains responsive to new notifications. startCtx is - // used for cancellation in case leadership is lost while - // retries are in progress. - // - // Epoch is incremented so stale tryStartQueueMaintainer - // goroutines from a previous term cannot call Stop after a - // new term has begun. - var startCtx context.Context - startCtx, cancelQueueMaintainerStart = context.WithCancelCause(ctx) - - c.queueMaintainerMu.Lock() - c.queueMaintainerEpoch++ - epoch := c.queueMaintainerEpoch - c.queueMaintainerMu.Unlock() - - go c.tryStartQueueMaintainer(startCtx, epoch) - - default: - // Cancel any in-progress start attempts before stopping. - // Send a startstop.ErrStop to make sure services like - // Reindexer run any specific cleanup code for stops. - cancelQueueMaintainerStart(startstop.ErrStop) - cancelQueueMaintainerStart = func(_ error) {} - - c.queueMaintainer.Stop() - } - } - } - }() - - return nil -} - -// Tries to start the queue maintainer after gaining leadership. We allow some -// retries with exponential backoff in case of failure, and in case the queue -// maintainer can't be started, we request resignation to allow another client -// to try and take over. -func (c *Client[TTx]) tryStartQueueMaintainer(ctx context.Context, epoch int64) { - const maxStartAttempts = 3 - - ctxCancelled := func() bool { - if ctx.Err() != nil { - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Queue maintainer start cancelled") - return true - } - return false - } - - // stopIfCurrentEpoch atomically checks whether this goroutine's epoch is - // still the active one and calls Stop only if it is. Combined with the - // epoch increment in handleLeadershipChangeLoop, prevents stale goroutine - // from stopping a maintainer started by a newer leadership term. 
- stopIfCurrentEpoch := func() bool { - c.queueMaintainerMu.Lock() - defer c.queueMaintainerMu.Unlock() - - if c.queueMaintainerEpoch != epoch { - return false - } - - c.queueMaintainer.Stop() - return true - } - - var lastErr error - for attempt := 1; attempt <= maxStartAttempts; attempt++ { - if ctxCancelled() { - return - } - - if lastErr = c.queueMaintainer.Start(ctx); lastErr == nil { - return - } - - c.baseService.Logger.ErrorContext(ctx, c.baseService.Name+": Error starting queue maintainer", - slog.String("err", lastErr.Error()), slog.Int("attempt", attempt)) - - c.testSignals.queueMaintainerStartError.Signal(lastErr) - - // Stop the queue maintainer to fully reset its state (and any - // sub-services) before retrying. The epoch check ensures a stale - // goroutine cannot stop a maintainer from a newer leadership term. - if !stopIfCurrentEpoch() { - return - } - - if attempt < maxStartAttempts { - serviceutil.CancellableSleep(ctx, serviceutil.ExponentialBackoff(attempt, serviceutil.MaxAttemptsBeforeResetDefault)) - } - } - - if ctxCancelled() { - return - } - - c.baseService.Logger.ErrorContext(ctx, c.baseService.Name+": Queue maintainer failed to start after all attempts, requesting leader resignation", - slog.String("err", lastErr.Error())) - - c.testSignals.queueMaintainerStartRetriesExhausted.Signal(struct{}{}) - - if err := c.clientNotifyBundle.RequestResign(ctx); err != nil { - c.baseService.Logger.ErrorContext(ctx, c.baseService.Name+": Error requesting leader resignation", slog.String("err", err.Error())) - } -} - // Driver exposes the underlying driver used by the client. // // API is not stable. DO NOT USE. 
diff --git a/client_pilot_test.go b/client_pilot_test.go index 23a7b51d..7f47cd4b 100644 --- a/client_pilot_test.go +++ b/client_pilot_test.go @@ -281,7 +281,7 @@ func Test_Client_PilotUsage(t *testing.T) { pilot.testSignals.Init(t) startClient(ctx, t, client) - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() pilot.testSignals.PeriodicJobGetAll.WaitOrTimeout() pilot.testSignals.PeriodicJobUpsertMany.WaitOrTimeout() diff --git a/client_test.go b/client_test.go index 5a446068..87d140c9 100644 --- a/client_test.go +++ b/client_test.go @@ -1116,7 +1116,7 @@ func Test_Client_Common(t *testing.T) { startClient(ctx, t, client) client.config.Logger.InfoContext(ctx, "Test waiting for client to be elected leader for the first time") - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() client.config.Logger.InfoContext(ctx, "Client was elected leader for the first time") // We test the function with a forced resignation, but this is a general @@ -1124,7 +1124,7 @@ func Test_Client_Common(t *testing.T) { require.NoError(t, client.Notify().RequestResign(ctx)) client.config.Logger.InfoContext(ctx, "Test waiting for client to be elected leader after forced resignation") - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() client.config.Logger.InfoContext(ctx, "Client was elected leader after forced resignation") }) @@ -1137,7 +1137,7 @@ func Test_Client_Common(t *testing.T) { startClient(ctx, t, client) client.config.Logger.InfoContext(ctx, "Test waiting for client to be elected leader for the first time") - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() client.config.Logger.InfoContext(ctx, "Client was elected leader for the first time") tx, err := bundle.dbPool.Begin(ctx) @@ -1149,7 +1149,7 @@ 
func Test_Client_Common(t *testing.T) { require.NoError(t, tx.Commit(ctx)) client.config.Logger.InfoContext(ctx, "Test waiting for client to be elected leader after forced resignation") - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() client.config.Logger.InfoContext(ctx, "Client was elected leader after forced resignation") }) @@ -1422,7 +1422,7 @@ func Test_Client_Common(t *testing.T) { // Despite no notifier, the client should still be able to elect itself // leader. - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() event := riversharedtest.WaitOrTimeout(t, subscribeChan) require.Equal(t, EventKindJobCompleted, event.Kind) @@ -1450,7 +1450,7 @@ func Test_Client_Common(t *testing.T) { // Despite no notifier, the client should still be able to elect itself // leader. - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() event := riversharedtest.WaitOrTimeout(t, subscribeChan) require.Equal(t, EventKindJobCompleted, event.Kind) @@ -4881,7 +4881,7 @@ func Test_Client_Maintenance(t *testing.T) { t.Helper() startClient(ctx, t, client) - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() riversharedtest.WaitOrTimeout(t, client.queueMaintainer.Started()) } @@ -5158,16 +5158,16 @@ func Test_Client_Maintenance(t *testing.T) { client, _ := setup(t, config) startClient(ctx, t, client) - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() // Wait for all 3 retry attempts to fail. 
for range 3 { - err := client.testSignals.queueMaintainerStartError.WaitOrTimeout() + err := client.queueMaintainerLeader.TestSignals.StartError.WaitOrTimeout() require.EqualError(t, err, "hook start error") } // After all retries exhausted, the client should request resignation. - client.testSignals.queueMaintainerStartRetriesExhausted.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.StartRetriesExhausted.WaitOrTimeout() }) t.Run("PeriodicJobEnqueuerWithInsertOpts", func(t *testing.T) { @@ -5276,7 +5276,7 @@ func Test_Client_Maintenance(t *testing.T) { exec := client.driver.GetExecutor() - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() client.PeriodicJobs().Add( NewPeriodicJob(cron.Every(15*time.Minute), func() (JobArgs, *InsertOpts) { @@ -5322,7 +5322,7 @@ func Test_Client_Maintenance(t *testing.T) { startClient(ctx, t, client) - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() svc := maintenance.GetService[*maintenance.PeriodicJobEnqueuer](client.queueMaintainer) svc.TestSignals.EnteredLoop.WaitOrTimeout() @@ -5395,7 +5395,7 @@ func Test_Client_Maintenance(t *testing.T) { startClient(ctx, t, client) - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() svc := maintenance.GetService[*maintenance.PeriodicJobEnqueuer](client.queueMaintainer) svc.TestSignals.EnteredLoop.WaitOrTimeout() @@ -5440,7 +5440,7 @@ func Test_Client_Maintenance(t *testing.T) { startClient(ctx, t, client) - client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() svc := maintenance.GetService[*maintenance.PeriodicJobEnqueuer](client.queueMaintainer) svc.TestSignals.EnteredLoop.WaitOrTimeout() @@ -5535,7 +5535,7 @@ func Test_Client_Maintenance(t *testing.T) { startClient(ctx, t, client) - 
client.testSignals.electedLeader.WaitOrTimeout() + client.queueMaintainerLeader.TestSignals.ElectedLeader.WaitOrTimeout() qc := maintenance.GetService[*maintenance.QueueCleaner](client.queueMaintainer) qc.TestSignals.DeletedBatch.WaitOrTimeout() diff --git a/internal/maintenance/queue_maintainer_leader.go b/internal/maintenance/queue_maintainer_leader.go new file mode 100644 index 00000000..11697de7 --- /dev/null +++ b/internal/maintenance/queue_maintainer_leader.go @@ -0,0 +1,185 @@ +package maintenance + +import ( + "context" + "log/slog" + "sync" + + "github.com/riverqueue/river/internal/leadership" + "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/startstop" + "github.com/riverqueue/river/rivershared/testsignal" + "github.com/riverqueue/river/rivershared/util/serviceutil" + "github.com/riverqueue/river/rivershared/util/testutil" +) + +const queueMaintainerMaxStartAttempts = 3 + +// QueueMaintainerLeaderTestSignals are internal signals used exclusively in tests. +type QueueMaintainerLeaderTestSignals struct { + ElectedLeader testsignal.TestSignal[struct{}] // notifies when elected leader + StartError testsignal.TestSignal[error] // notifies on each failed queue maintainer start attempt + StartRetriesExhausted testsignal.TestSignal[struct{}] // notifies when all start retries have been exhausted +} + +func (ts *QueueMaintainerLeaderTestSignals) Init(tb testutil.TestingTB) { + ts.ElectedLeader.Init(tb) + ts.StartError.Init(tb) + ts.StartRetriesExhausted.Init(tb) +} + +// RequestResignFunc is a function that sends a notification requesting leader +// resignation. It's injected from the client because the notification mechanism +// depends on the driver, which the maintenance package doesn't know about. +type RequestResignFunc func(ctx context.Context) error + +// QueueMaintainerLeaderConfig is the configuration for QueueMaintainerLeader. 
+type QueueMaintainerLeaderConfig struct { + // ClientID is used for logging on leadership changes. + ClientID string + + // Elector provides leadership change notifications. + Elector *leadership.Elector + + // QueueMaintainer is the underlying maintainer to start/stop on leadership + // changes. + QueueMaintainer *QueueMaintainer + + // RequestResignFunc sends a notification requesting leader resignation. + RequestResignFunc RequestResignFunc +} + +// QueueMaintainerLeader listens for leadership changes and starts/stops the +// queue maintainer accordingly. It handles retries with exponential backoff on +// start failures, and requests leader resignation when all retries are +// exhausted. This is extracted to a separate struct because to get all the edge +// cases right, it ends up being a fair bit of code that would otherwise make +// Client fairly heavy. +type QueueMaintainerLeader struct { + baseservice.BaseService + startstop.BaseStartStop + + // exported for test purposes + TestSignals QueueMaintainerLeaderTestSignals + + config *QueueMaintainerLeaderConfig + + // epoch is incremented each time leadership is gained, giving each start + // goroutine a term number. mu serializes epoch checks with Stop calls so + // a stale goroutine cannot tear down a newer term's maintainer. + epoch int64 + mu sync.Mutex +} + +func NewQueueMaintainerLeader(archetype *baseservice.Archetype, config *QueueMaintainerLeaderConfig) *QueueMaintainerLeader { + return baseservice.Init(archetype, &QueueMaintainerLeader{ + config: config, + }) +} + +func (s *QueueMaintainerLeader) Start(ctx context.Context) error { + ctx, shouldStart, started, stopped := s.StartInit(ctx) + if !shouldStart { + return nil + } + + go func() { + started() + defer stopped() // this defer should come first so it's last out + + sub := s.config.Elector.Listen() + defer sub.Unlisten() + + // Cancel function for an in-progress start attempt. 
If leadership is + // lost while the start process is still retrying, this is used to abort it + // promptly instead of waiting for retries to finish. + var cancelStart context.CancelCauseFunc = func(_ error) {} + + for { + select { + case <-ctx.Done(): + cancelStart(context.Cause(ctx)) + return + + case notification := <-sub.C(): + s.Logger.DebugContext(ctx, s.Name+": Election change received", + slog.String("client_id", s.config.ClientID), slog.Bool("is_leader", notification.IsLeader)) + + switch { + case notification.IsLeader: + s.TestSignals.ElectedLeader.Signal(struct{}{}) + + // Start with retries in a separate goroutine so the + // leadership change loop remains responsive. + var startCtx context.Context + startCtx, cancelStart = context.WithCancelCause(ctx) + + s.mu.Lock() + s.epoch++ + epoch := s.epoch + s.mu.Unlock() + + go s.tryStart(startCtx, epoch) + + default: + // Cancel any in-progress start attempts before stopping. + // Send ErrStop so services like Reindexer run cleanup. + cancelStart(startstop.ErrStop) + cancelStart = func(_ error) {} + + s.config.QueueMaintainer.Stop() + } + } + } + }() + + return nil +} + +func (s *QueueMaintainerLeader) tryStart(ctx context.Context, epoch int64) { + var lastErr error + for attempt := 1; attempt <= queueMaintainerMaxStartAttempts; attempt++ { + if ctx.Err() != nil { + return + } + + if lastErr = s.config.QueueMaintainer.Start(ctx); lastErr == nil { + return + } + + s.Logger.ErrorContext(ctx, s.Name+": Error starting queue maintainer", + slog.String("err", lastErr.Error()), slog.Int("attempt", attempt)) + + s.TestSignals.StartError.Signal(lastErr) + + // Stop to fully reset state before retrying. The mutex serializes + // the epoch check with the increment in Start so a stale goroutine + // cannot tear down a newer term's maintainer. 
+ s.mu.Lock() + stale := s.epoch != epoch + if !stale { + s.config.QueueMaintainer.Stop() + } + s.mu.Unlock() + if stale { + return + } + + if attempt < queueMaintainerMaxStartAttempts { + serviceutil.CancellableSleep(ctx, serviceutil.ExponentialBackoff(attempt, serviceutil.MaxAttemptsBeforeResetDefault)) + } + } + + if ctx.Err() != nil { + return + } + + s.Logger.ErrorContext(ctx, s.Name+": Queue maintainer failed to start after all attempts, requesting leader resignation", + slog.String("err", lastErr.Error())) + + s.TestSignals.StartRetriesExhausted.Signal(struct{}{}) + + if err := s.config.RequestResignFunc(ctx); err != nil { + s.Logger.ErrorContext(ctx, s.Name+": Error requesting leader resignation", slog.String("err", err.Error())) + } +} diff --git a/internal/maintenance/queue_maintainer_leader_test.go b/internal/maintenance/queue_maintainer_leader_test.go new file mode 100644 index 00000000..9865304c --- /dev/null +++ b/internal/maintenance/queue_maintainer_leader_test.go @@ -0,0 +1,134 @@ +package maintenance + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/leadership" + "github.com/riverqueue/river/riverdbtest" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/river/rivershared/startstop" +) + +func TestQueueMaintainerLeader(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + setup := func(t *testing.T, maintainer *QueueMaintainer) *QueueMaintainerLeader { + t.Helper() + + var ( + dbPool = riversharedtest.DBPool(ctx, t) + driver = riverpgxv5.New(dbPool) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + archetype = riversharedtest.BaseServiceArchetype(t) + ) + + elector := leadership.NewElector(archetype, driver.GetExecutor(), nil, &leadership.Config{ + ClientID: "test_client_id", + Schema: schema, + }) + require.NoError(t, elector.Start(ctx)) + 
t.Cleanup(elector.Stop) + + leader := NewQueueMaintainerLeader(archetype, &QueueMaintainerLeaderConfig{ + ClientID: "test_client_id", + Elector: elector, + QueueMaintainer: maintainer, + RequestResignFunc: func(ctx context.Context) error { + return nil + }, + }) + leader.TestSignals.Init(t) + + return leader + } + + t.Run("StartsMaintainerOnLeadershipGain", func(t *testing.T) { + t.Parallel() + + testSvc := newTestService(t) + maintainer := NewQueueMaintainer(riversharedtest.BaseServiceArchetype(t), []startstop.Service{testSvc}) + maintainer.StaggerStartupDisable(true) + + leader := setup(t, maintainer) + + require.NoError(t, leader.Start(ctx)) + t.Cleanup(leader.Stop) + + leader.TestSignals.ElectedLeader.WaitOrTimeout() + testSvc.testSignals.started.WaitOrTimeout() + + leader.Stop() + testSvc.testSignals.returning.WaitOrTimeout() + }) + + t.Run("RetriesAndResignsOnStartFailure", func(t *testing.T) { + t.Parallel() + + var startAttempts atomic.Int64 + failingSvc := &failingStartService{startAttempts: &startAttempts} + + maintainer := NewQueueMaintainer(riversharedtest.BaseServiceArchetype(t), []startstop.Service{failingSvc}) + maintainer.StaggerStartupDisable(true) + + resignCalled := make(chan struct{}) + archetype := riversharedtest.BaseServiceArchetype(t) + + var ( + dbPool = riversharedtest.DBPool(ctx, t) + driver = riverpgxv5.New(dbPool) + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + ) + + elector := leadership.NewElector(archetype, driver.GetExecutor(), nil, &leadership.Config{ + ClientID: "test_client_id", + Schema: schema, + }) + require.NoError(t, elector.Start(ctx)) + t.Cleanup(elector.Stop) + + leader := NewQueueMaintainerLeader(archetype, &QueueMaintainerLeaderConfig{ + ClientID: "test_client_id", + Elector: elector, + QueueMaintainer: maintainer, + RequestResignFunc: func(ctx context.Context) error { + close(resignCalled) + return nil + }, + }) + leader.TestSignals.Init(t) + + require.NoError(t, leader.Start(ctx)) + t.Cleanup(leader.Stop) + + 
leader.TestSignals.ElectedLeader.WaitOrTimeout() + + for range queueMaintainerMaxStartAttempts { + err := leader.TestSignals.StartError.WaitOrTimeout() + require.EqualError(t, err, "start error") + } + + leader.TestSignals.StartRetriesExhausted.WaitOrTimeout() + riversharedtest.WaitOrTimeout(t, resignCalled) + require.Equal(t, int64(queueMaintainerMaxStartAttempts), startAttempts.Load()) + }) +} + +// failingStartService is a service whose Start always returns an error. +type failingStartService struct { + startstop.BaseStartStop + + startAttempts *atomic.Int64 +} + +func (s *failingStartService) Start(ctx context.Context) error { + s.startAttempts.Add(1) + return errors.New("start error") +}