diff --git a/README.md b/README.md index 25d4511..444a5bf 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,61 @@ if err != nil { } ``` +### Drain on Stop + +By default, `Stop` cancels the runFunc's context, which aborts in-flight +work. For workers that own external calls that must complete (e.g. an +HTTP request that creates remote state), use `WithDrain` to switch to +"signal-and-wait" semantics: `Stop` closes the channel returned by +`Stopping(ctx)` and waits up to the drain timeout for runFunc to return +on its own. If the timeout elapses, `Stop` falls back to cancelling the +context and returns `ErrDrainTimedOut` once runFunc exits. + +Always select on both `<-ctx.Done()` and `<-runnable.Stopping(ctx)` — +`Stopping` signals only `Stop`; outer-context cancellation still arrives +via `ctx.Done()` and a loop that ignores it will hang. + +```go +r := runnable.New(func(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-runnable.Stopping(ctx): + return nil // finish in-flight work, then return + case <-time.After(time.Second): + doWork(ctx) + } + } +}, runnable.WithDrain(10*time.Second)) +``` + +### Ticker + +`NewTicker` wraps the standard "select-loop on a `time.Ticker`" pattern. +It composes with `WithDrain` (let the current tick finish on Stop) and +`WithRecoverer` (catch panics in the tick body). + +```go +r := runnable.NewTicker( + 30*time.Second, + func(ctx context.Context) error { + return reconcile(ctx) + }, + runnable.WithDrain(10*time.Second), +) + +go r.Run(ctx) + +// On shutdown: +stopCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) +defer cancel() +r.Stop(stopCtx) // drains the in-flight tick before returning +``` + +A full SIGTERM-safe shape (ticker + drain + recoverer + signal.NotifyContext) +lives in [`examples/ticker-with-drain`](examples/ticker-with-drain/main.go). + ### Runnable Object ```go package main diff --git a/examples/main.go b/examples/main.go index e1e710a..b799017 100644 --- a/examples/main.go +++ b/examples/main.go @@ -33,7 +33,6 @@ func (m *Monitor) run(ctx context.Context) error { time.Sleep(1 * time.Second) fmt.Println("Monitoring...") } - return nil } func main() { @@ -92,7 +91,8 @@ func main() { // simple function with timeout fmt.Println("Simple function with timeout...") - ctxWithTimeout, _ := context.WithTimeout(context.Background(), 5*time.Second) + ctxWithTimeout, cancelTimeout := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelTimeout() err = runnable.New(func(ctx context.Context) error { fmt.Println("Starting...") defer fmt.Println("Stopping...") diff --git a/examples/ticker-with-drain/main.go b/examples/ticker-with-drain/main.go new file mode 100644 index 0000000..5eb85a2 --- /dev/null +++ b/examples/ticker-with-drain/main.go @@ -0,0 +1,83 @@ +// Example: a periodic reconciler that drains gracefully on SIGTERM. +// +// Shape: NewTicker + WithDrain + WithRecoverer + signal.NotifyContext. +// Copy-paste this into a service's cmd/.../main.go and replace the +// reconcile body with your work. 
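+//
+// From the repo root, `go run ./examples/ticker-with-drain`, then send
+// SIGINT or SIGTERM mid-tick: "shutdown: draining in-flight tick..."
+// prints first, and the in-flight tick's "tick: done" still follows
+// before the process exits.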
+package main + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/0xsequence/runnable" +) + +type stderrReporter struct{} + +func (stderrReporter) Report(ctx context.Context, rec interface{}) { + fmt.Fprintf(os.Stderr, "panic recovered: %v\n", rec) +} + +type stderrPrinter struct{} + +func (stderrPrinter) Print(ctx context.Context, callstack []byte) { + _, _ = os.Stderr.Write(callstack) +} + +func reconcile(ctx context.Context) error { + // Pretend this is an HTTP call to an external system that must not + // be aborted mid-request when SIGTERM fires. Under WithDrain, Stop + // waits for this tick to finish before tearing down the Runnable. + fmt.Println("tick: reconciling...") + time.Sleep(500 * time.Millisecond) + fmt.Println("tick: done") + return nil +} + +func main() { + sigCtx, stopSig := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer stopSig() + + rc := runnable.NewTicker( + 2*time.Second, + reconcile, + runnable.WithDrain(10*time.Second), + runnable.WithRecoverer(stderrReporter{}, stderrPrinter{}), + ) + + // Run with a pristine ctx — if Run received sigCtx, SIGTERM would + // cancel runFunc's ctx directly and the ticker would exit before + // Stop ever closed Stopping(ctx), defeating WithDrain. Stop is the + // only thing that should drive shutdown of a drain-enabled worker. + runErr := make(chan error, 1) + go func() { + runErr <- rc.Run(context.Background()) + }() + + // Wait for either a shutdown signal or an early worker exit + // (tick error, recovered panic). Without the runErr branch, main + // would block on sigCtx forever after the worker died. + select { + case <-sigCtx.Done(): + fmt.Println("shutdown: draining in-flight tick...") + stopCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if err := rc.Stop(stopCtx); err != nil { + fmt.Fprintf(os.Stderr, "stop: %v\n", err) + } + if err := <-runErr; err != nil && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "reconciler stopped: %v\n", err) + os.Exit(1) + } + case err := <-runErr: + if err != nil && !errors.Is(err, context.Canceled) { + fmt.Fprintf(os.Stderr, "reconciler stopped: %v\n", err) + os.Exit(1) + } + } +} diff --git a/runnable.go b/runnable.go index ca62fef..eb57ee7 100644 --- a/runnable.go +++ b/runnable.go @@ -4,11 +4,13 @@ import ( "context" "fmt" "sync" + "time" ) var ( ErrAlreadyRunning = fmt.Errorf("already running") ErrNotRunning = fmt.Errorf("not running") + ErrDrainTimedOut = fmt.Errorf("drain timed out") ) type Option interface { @@ -28,6 +30,10 @@ type runnable struct { runCancel context.CancelFunc runStop chan bool + drainEnabled bool + drainTimeout time.Duration + stoppingChan chan struct{} + isRunning bool onStart func() onStop func() @@ -92,6 +98,10 @@ func (r *runnable) Run(ctx context.Context) error { r.runStop = make(chan bool) runCtx := r.runCtx + if r.drainEnabled { + r.stoppingChan = make(chan struct{}) + runCtx = context.WithValue(runCtx, stoppingKey{}, (<-chan struct{})(r.stoppingChan)) + } r.mu.Unlock() defer func() { @@ -134,14 +144,62 @@ func (r *runnable) Stop(ctx context.Context) error { } runStop := r.runStop + // Snapshot runCancel under the lock — the field is overwritten by + // the next Run, so reading r.runCancel() after waiting can cancel + // a *future* runnable that started after this Stop began draining. 
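+	// Everything after the unlock below uses these locals only, never
+	// the struct fields, so a fresh Run that starts mid-drain is never
+	// touched by this Stop.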
+ runCancel := r.runCancel + drainEnabled := r.drainEnabled + drainTimeout := r.drainTimeout + stoppingChan := r.stoppingChan + r.stoppingChan = nil // first-caller wins; subsequent Stops see nil r.mu.Unlock() - r.runCancel() + // Concurrent Stop with drain enabled: another caller is already + // driving the drain. Wait for its outcome rather than calling + // runCancel ourselves — that would hard-cancel the runCtx and + // defeat the drain the primary caller is honoring. If our ctx + // expires first, escalate to runCancel so the shortest deadline + // among concurrent callers wins. + if drainEnabled && stoppingChan == nil { + select { + case <-runStop: + return nil + case <-ctx.Done(): + runCancel() + return ctx.Err() + } + } + + var drainTimedOut bool + if drainEnabled { + close(stoppingChan) + // Use a standalone timer so the drain budget is independent of + // the caller's ctx — otherwise a caller ctx shorter than + // drainTimeout makes <-ctx.Done() and the drain expiry race in + // the same select. + drainTimer := time.NewTimer(drainTimeout) + select { + case <-runStop: + drainTimer.Stop() + return nil + case <-drainTimer.C: + drainTimedOut = true + case <-ctx.Done(): + drainTimer.Stop() + // Caller deadline elapsed during drain; fall through so + // r.runCancel() still fires before we return ctx.Err(). + } + } + + runCancel() select { case <-ctx.Done(): return ctx.Err() case <-runStop: + if drainTimedOut { + return ErrDrainTimedOut + } return nil } } diff --git a/runnable_group_test.go b/runnable_group_test.go index 6efa922..fb70ed9 100644 --- a/runnable_group_test.go +++ b/runnable_group_test.go @@ -48,9 +48,7 @@ func TestNewGroup(t *testing.T) { // Create a new group group := NewGroup( New(func(ctx context.Context) error { - select { - case <-ctx.Done(): - } + <-ctx.Done() return nil }), New(func(ctx context.Context) error { diff --git a/runnable_test.go b/runnable_test.go index 045be86..ea1b2a8 100644 --- a/runnable_test.go +++ b/runnable_test.go @@ -43,11 +43,8 @@ func TestRunnable(t *testing.T) { r := New(func(ctx context.Context) error { started <- struct{}{} time.Sleep(2 * time.Second) - - select { - case <-ctx.Done(): - return ctx.Err() - } + <-ctx.Done() + return ctx.Err() }) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) @@ -81,18 +78,20 @@ func TestRunnable(t *testing.T) { assert.Equal(t, false, r.IsRunning()) }) - t.Run("runnable, stop timeout", func(t *testing.T) { + t.Run("runnable, stop timeout proves no drain behavior", func(t *testing.T) { started := make(chan struct{}) + ctxCancelObserved := make(chan struct{}) r := New(func(ctx context.Context) error { started <- struct{}{} + <-ctx.Done() + close(ctxCancelObserved) time.Sleep(2 * time.Second) - return nil + return ctx.Err() }) go func() { - err := r.Run(context.Background()) - require.NoError(t, err) + _ = r.Run(context.Background()) }() <-started @@ -101,7 +100,15 @@ func TestRunnable(t *testing.T) { stopCtx, stopCancel := context.WithTimeout(context.Background(), 1*time.Second) defer stopCancel() err := r.Stop(stopCtx) - require.Error(t, err, context.DeadlineExceeded) + require.ErrorIs(t, err, context.DeadlineExceeded) + + // Without WithDrain, Stop cancels runFunc's ctx immediately. 
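+		// The 100ms window is deliberately much shorter than runFunc's
+		// 2s sleep: a pass means the cancel arrived at Stop time, not
+		// when runFunc finally returned.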
+ select { + case <-ctxCancelObserved: + case <-time.After(100 * time.Millisecond): + t.Fatal("expected runFunc's ctx to be cancelled when Stop fires without WithDrain") + } + assert.Equal(t, true, r.IsRunning()) }) } diff --git a/ticker.go b/ticker.go new file mode 100644 index 0000000..9a268d0 --- /dev/null +++ b/ticker.go @@ -0,0 +1,56 @@ +package runnable + +import ( + "context" + "time" +) + +// NewTicker returns a Runnable that calls tick once per interval until +// ctx is cancelled, Stop is called, or tick returns a non-nil error. +// +// When the Runnable is configured WithDrain, an in-flight tick is +// allowed to finish before Run returns; the loop exits without firing +// a new tick. Without WithDrain, Stop cancels ctx and any in-flight +// tick observes the cancellation through ctx.Done(). +// +// tick should respect ctx.Done() for cancellation. To make in-flight +// external calls survive shutdown under WithDrain, tick should derive +// per-call timeouts via context.WithoutCancel(ctx) so its work is not +// affected by either Stop's drain signal or the Runnable's ctx cancel. +// +// Composing with WithRetry resets the ticker cadence on every retry: +// a tick error bails the loop, WithRetry re-enters runFunc, and the +// next tick fires `interval` after the retry — not at the original +// cadence. If you need stable cadence with transient-error tolerance, +// handle retries inside `tick` instead. +func NewTicker(interval time.Duration, tick func(ctx context.Context) error, opts ...Option) Runnable { + return New(func(ctx context.Context) error { + t := time.NewTicker(interval) + defer t.Stop() + + stopping := Stopping(ctx) // nil when WithDrain not used + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-stopping: + return nil + case <-t.C: + // Re-check shutdown signals before firing a new tick: + // when a tick takes longer than the interval, t.C may + // have ready ticks queued from before Stop was called. + select { + case <-ctx.Done(): + return ctx.Err() + case <-stopping: + return nil + default: + } + if err := tick(ctx); err != nil { + return err + } + } + } + }, opts...) 
+} diff --git a/ticker_test.go b/ticker_test.go new file mode 100644 index 0000000..d4ea362 --- /dev/null +++ b/ticker_test.go @@ -0,0 +1,135 @@ +package runnable + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTicker(t *testing.T) { + t.Run("fires on interval", func(t *testing.T) { + var count atomic.Int32 + + r := NewTicker(50*time.Millisecond, func(ctx context.Context) error { + count.Add(1) + return nil + }) + + go func() { + _ = r.Run(context.Background()) + }() + + time.Sleep(175 * time.Millisecond) + + err := r.Stop(context.Background()) + require.NoError(t, err) + + c := count.Load() + assert.GreaterOrEqual(t, c, int32(2)) + assert.LessOrEqual(t, c, int32(4)) + }) + + t.Run("Stop with drain allows current tick to finish", func(t *testing.T) { + tickStarted := make(chan struct{}, 1) + var completed atomic.Int32 + + r := NewTicker(20*time.Millisecond, func(ctx context.Context) error { + select { + case tickStarted <- struct{}{}: + default: + } + time.Sleep(200 * time.Millisecond) + completed.Add(1) + return nil + }, WithDrain(1*time.Second)) + + go func() { + _ = r.Run(context.Background()) + }() + + <-tickStarted + + start := time.Now() + err := r.Stop(context.Background()) + elapsed := time.Since(start) + require.NoError(t, err) + + assert.GreaterOrEqual(t, completed.Load(), int32(1), "in-flight tick should complete") + assert.Less(t, elapsed, 500*time.Millisecond) + }) + + t.Run("Stop without drain cancels in-flight tick", func(t *testing.T) { + tickStarted := make(chan struct{}, 1) + tickErr := make(chan error, 1) + + r := NewTicker(20*time.Millisecond, func(ctx context.Context) error { + select { + case tickStarted <- struct{}{}: + default: + } + <-ctx.Done() + tickErr <- ctx.Err() + return ctx.Err() + }) + + runDone := make(chan error, 1) + go func() { + runDone <- r.Run(context.Background()) + }() + + <-tickStarted + err := r.Stop(context.Background()) + require.NoError(t, err) + + select { + case e := <-tickErr: + require.ErrorIs(t, e, context.Canceled) + case <-time.After(time.Second): + t.Fatal("tick did not observe ctx cancellation") + } + + select { + case e := <-runDone: + require.ErrorIs(t, e, context.Canceled) + case <-time.After(time.Second): + t.Fatal("Run did not return") + } + }) + + t.Run("tick error aborts loop", func(t *testing.T) { + sentinel := errors.New("boom") + var count atomic.Int32 + + r := NewTicker(20*time.Millisecond, func(ctx context.Context) error { + if count.Add(1) == 2 { + return sentinel + } + return nil + }) + + err := r.Run(context.Background()) + require.ErrorIs(t, err, sentinel) + assert.Equal(t, int32(2), count.Load()) + }) + + t.Run("respects outer ctx cancel", func(t *testing.T) { + var count atomic.Int32 + + r := NewTicker(20*time.Millisecond, func(ctx context.Context) error { + count.Add(1) + return nil + }) + + ctx, cancel := context.WithTimeout(context.Background(), 75*time.Millisecond) + defer cancel() + + err := r.Run(ctx) + require.ErrorIs(t, err, context.DeadlineExceeded) + assert.False(t, r.IsRunning()) + }) +} diff --git a/with_drain.go b/with_drain.go new file mode 100644 index 0000000..306be8a --- /dev/null +++ b/with_drain.go @@ -0,0 +1,46 @@ +package runnable + +import ( + "context" + "time" +) + +type stoppingKey struct{} + +// Stopping returns a channel that closes when Stop has been called on +// the Runnable that owns ctx. 
runFunc implementations under WithDrain +// should select on it and return cleanly without cancelling in-flight +// work. +// +// Always also select on ctx.Done() — Stopping signals only Stop; +// outer-context cancellation (e.g. the ctx passed to Run was cancelled +// directly) still arrives via ctx.Done(). A loop that selects only on +// Stopping(ctx) will hang on outer-ctx cancel. +// +// Returns a nil channel when ctx is not associated with a drain-enabled +// Runnable — receiving from a nil channel blocks forever, which is the +// correct no-op for callers that opt into drain semantics only when +// configured. +func Stopping(ctx context.Context) <-chan struct{} { + ch, _ := ctx.Value(stoppingKey{}).(<-chan struct{}) + return ch +} + +type withDrain struct { + timeout time.Duration +} + +// WithDrain switches Stop's behavior from "cancel runFunc's ctx" to +// "close Stopping(ctx) and wait up to timeout for runFunc to return on +// its own." After the timeout elapses, Stop falls back to cancelling +// the ctx as before (preserving the existing escape hatch for stuck +// runFuncs). Use this when runFunc owns in-flight external calls that +// must drain rather than abort on shutdown. +func WithDrain(timeout time.Duration) Option { + return &withDrain{timeout: timeout} +} + +func (w *withDrain) apply(r *runnable) { + r.drainEnabled = true + r.drainTimeout = w.timeout +} diff --git a/with_drain_test.go b/with_drain_test.go new file mode 100644 index 0000000..55ec879 --- /dev/null +++ b/with_drain_test.go @@ -0,0 +1,405 @@ +package runnable + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWithDrain(t *testing.T) { + t.Run("Stop waits for runFunc to return", func(t *testing.T) { + started := make(chan struct{}) + runFuncErr := make(chan error, 1) + + r := New(func(ctx context.Context) error { + close(started) + <-Stopping(ctx) + time.Sleep(200 * time.Millisecond) + // Return naturally without observing ctx cancellation. 
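+			// During a successful drain runCancel never fires, so this
+			// check should observe a nil ctx.Err() and return nil.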
+ if ctx.Err() != nil { + return ctx.Err() + } + return nil + }, WithDrain(1*time.Second)) + + go func() { + runFuncErr <- r.Run(context.Background()) + }() + + <-started + assert.True(t, r.IsRunning()) + + start := time.Now() + err := r.Stop(context.Background()) + elapsed := time.Since(start) + require.NoError(t, err) + assert.False(t, r.IsRunning()) + assert.GreaterOrEqual(t, elapsed, 200*time.Millisecond) + assert.Less(t, elapsed, 500*time.Millisecond) + + select { + case err := <-runFuncErr: + require.NoError(t, err, "runFunc should return naturally, not via ctx cancellation") + case <-time.After(time.Second): + t.Fatal("runFunc did not return") + } + }) + + t.Run("Stop returns ErrDrainTimedOut on fall-through", func(t *testing.T) { + started := make(chan struct{}) + runFuncErr := make(chan error, 1) + + r := New(func(ctx context.Context) error { + close(started) + <-ctx.Done() + return ctx.Err() + }, WithDrain(100*time.Millisecond)) + + go func() { + runFuncErr <- r.Run(context.Background()) + }() + + <-started + + stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer stopCancel() + err := r.Stop(stopCtx) + require.ErrorIs(t, err, ErrDrainTimedOut) + assert.False(t, r.IsRunning()) + + select { + case err := <-runFuncErr: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(time.Second): + t.Fatal("runFunc did not return") + } + }) + + t.Run("Stop returns DeadlineExceeded when runFunc stuck", func(t *testing.T) { + started := make(chan struct{}) + release := make(chan struct{}) + + r := New(func(ctx context.Context) error { + close(started) + <-release + return nil + }, WithDrain(50*time.Millisecond)) + + go func() { + _ = r.Run(context.Background()) + }() + + <-started + + stopCtx, stopCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer stopCancel() + err := r.Stop(stopCtx) + require.ErrorIs(t, err, context.DeadlineExceeded) + + // Release the runFunc so the goroutine can exit cleanly. + close(release) + }) + + t.Run("outer ctx cancel still propagates", func(t *testing.T) { + started := make(chan struct{}) + runFuncErr := make(chan error, 1) + + r := New(func(ctx context.Context) error { + close(started) + <-ctx.Done() + return ctx.Err() + }, WithDrain(1*time.Second)) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + runFuncErr <- r.Run(ctx) + }() + + <-started + cancel() + + select { + case err := <-runFuncErr: + require.True(t, errors.Is(err, context.Canceled)) + case <-time.After(time.Second): + t.Fatal("runFunc did not exit on outer ctx cancel") + } + assert.False(t, r.IsRunning()) + }) + + t.Run("Stop forces cancel when caller ctx expires during drain", func(t *testing.T) { + started := make(chan struct{}) + runFuncDone := make(chan struct{}) + + // runFunc respects its own ctx but not Stopping(ctx). Without + // the independent drain timer, Stop with a caller ctx shorter + // than drainTimeout could return ctx.Err() before r.runCancel() + // fired, leaving the runnable alive. 
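+		// With the standalone timer, Stop fires the snapshotted runCancel
+		// before returning ctx.Err(); the runFuncDone wait below checks
+		// that the cancel actually reached runFunc.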
+ r := New(func(ctx context.Context) error { + close(started) + <-ctx.Done() + close(runFuncDone) + return ctx.Err() + }, WithDrain(10*time.Second)) + + go func() { + _ = r.Run(context.Background()) + }() + + <-started + + stopCtx, stopCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer stopCancel() + err := r.Stop(stopCtx) + require.ErrorIs(t, err, context.DeadlineExceeded) + + select { + case <-runFuncDone: + case <-time.After(2 * time.Second): + t.Fatal("runnable was not force-cancelled when caller ctx expired during drain") + } + }) + + t.Run("concurrent Stop preserves drain semantics", func(t *testing.T) { + started := make(chan struct{}) + drainObserved := make(chan struct{}) + var ctxCancelObserved atomic.Bool + + // runFunc must exit via Stopping(ctx). If a concurrent Stop + // falls through to r.runCancel(), ctx.Done() fires and the + // drain semantics are violated. + r := New(func(ctx context.Context) error { + close(started) + select { + case <-Stopping(ctx): + close(drainObserved) + return nil + case <-ctx.Done(): + ctxCancelObserved.Store(true) + return ctx.Err() + } + }, WithDrain(2*time.Second)) + + go func() { + _ = r.Run(context.Background()) + }() + + <-started + + const callers = 10 + var wg sync.WaitGroup + errs := make([]error, callers) + for i := 0; i < callers; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + errs[i] = r.Stop(context.Background()) + }() + } + wg.Wait() + + // Each Stop must return either nil (drove or waited on the + // drain) or ErrNotRunning (Run already exited before this + // caller grabbed the lock). No double-close panic. + for _, err := range errs { + if err != nil { + require.ErrorIs(t, err, ErrNotRunning) + } + } + + select { + case <-drainObserved: + default: + t.Fatal("runFunc never observed Stopping(ctx); drain was bypassed by concurrent Stop") + } + assert.False(t, ctxCancelObserved.Load(), "drain semantics violated: runCtx was hard-cancelled by a concurrent Stop") + assert.False(t, r.IsRunning()) + }) + + t.Run("secondary Stop with shorter deadline escalates runCancel", func(t *testing.T) { + started := make(chan struct{}) + runFuncDone := make(chan struct{}) + + // runFunc waits only on ctx.Done() (ignores Stopping). Without + // escalation, Stop B's deadline expires but the runnable keeps + // draining for the full drainTimeout (5s). + r := New(func(ctx context.Context) error { + close(started) + <-ctx.Done() + close(runFuncDone) + return ctx.Err() + }, WithDrain(5*time.Second)) + + go func() { + _ = r.Run(context.Background()) + }() + + <-started + + // Stop A: no deadline; primary, drives drain. + aDone := make(chan error, 1) + go func() { + aDone <- r.Stop(context.Background()) + }() + + time.Sleep(20 * time.Millisecond) + + // Stop B: 100ms deadline; secondary. Must escalate so runFunc + // exits within the caller's budget. 
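+		// Escalation is the concurrent-Stop branch in Stop: B finds
+		// stoppingChan already nil, waits on runStop or its own ctx, and
+		// calls the snapshotted runCancel when the ctx expires first.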
+ bCtx, bCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer bCancel() + start := time.Now() + bErr := r.Stop(bCtx) + bElapsed := time.Since(start) + require.ErrorIs(t, bErr, context.DeadlineExceeded) + assert.Less(t, bElapsed, 500*time.Millisecond, "Stop B should not wait beyond its own deadline") + + select { + case <-runFuncDone: + case <-time.After(time.Second): + t.Fatal("runnable was not force-cancelled when secondary Stop's ctx expired") + } + + select { + case err := <-aDone: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("Stop A did not return after runFunc exited") + } + }) + + t.Run("runnable can be re-Run after a concurrent-Stop lifecycle", func(t *testing.T) { + // Lifecycle survival smoke test: a runnable that's been + // stopped via concurrent Stops (including one with an + // already-cancelled ctx hitting the runCancel escalation path) + // can be re-Run on the same instance and complete cleanly. + // + // This does NOT deterministically cover the runCancel-snapshot + // race in runnable.go (where a stale Stop could in principle + // reach r.runCancel after Run has overwritten the field). That + // race requires pausing the secondary Stop between its lock + // release and runCancel call while a fresh Run executes — + // achievable only via testing/synctest or a runtime hook. + // Both are out of scope here; the snapshot fix is verified by + // inspection, not this test. + r := New(func(ctx context.Context) error { + select { + case <-Stopping(ctx): + return nil + case <-ctx.Done(): + return ctx.Err() + } + }, WithDrain(1*time.Second)) + + go func() { + _ = r.Run(context.Background()) + }() + + for !r.IsRunning() { + time.Sleep(time.Millisecond) + } + + // Primary Stop, no deadline — drives drain. + primaryDone := make(chan error, 1) + go func() { + primaryDone <- r.Stop(context.Background()) + }() + + // Secondary Stop with an already-cancelled ctx — exercises + // the ctx.Done() escalation path that calls runCancel. + cancelledCtx, cancel := context.WithCancel(context.Background()) + cancel() + _ = r.Stop(cancelledCtx) + + <-primaryDone + for r.IsRunning() { + time.Sleep(time.Millisecond) + } + + // Round 2 — same runnable, fresh Run. Should run undisturbed + // until we Stop it. + round2Done := make(chan error, 1) + go func() { + round2Done <- r.Run(context.Background()) + }() + + select { + case err := <-round2Done: + t.Fatalf("round-2 runnable exited prematurely: %v", err) + case <-time.After(150 * time.Millisecond): + } + + require.NoError(t, r.Stop(context.Background())) + select { + case err := <-round2Done: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("round-2 runnable did not exit after Stop") + } + }) + + t.Run("WithRetry stops retrying after Stopping fires", func(t *testing.T) { + started := make(chan struct{}, 1) + var attempts atomic.Int32 + + // runFunc errors transiently every time. Without the + // Stopping-aware retry guard, WithRetry keeps re-entering + // runFunc after Stop is called. 
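+		// With the guard in place the wrapper sees the closed Stopping
+		// channel right after the first failed attempt and gives up,
+		// which the attempts == 1 assertion below pins down.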
+ r := New(func(ctx context.Context) error { + select { + case started <- struct{}{}: + default: + } + attempts.Add(1) + <-Stopping(ctx) + return errors.New("transient") + }, WithDrain(2*time.Second), WithRetry(100, ResetNever)) + + runDone := make(chan error, 1) + go func() { + runDone <- r.Run(context.Background()) + }() + + <-started + + require.NoError(t, r.Stop(context.Background())) + + select { + case <-runDone: + case <-time.After(time.Second): + t.Fatal("Run did not return after Stop") + } + + // Exactly one attempt should have run — the retry wrapper + // must observe Stopping and abandon further attempts. + assert.Equal(t, int32(1), attempts.Load(), "retry continued after Stop drained") + }) + + t.Run("Stopping returns nil when not configured", func(t *testing.T) { + var observed bool + + r := New(func(ctx context.Context) error { + ch := Stopping(ctx) + // Selecting on nil channel blocks forever; default branch runs. + select { + case <-ch: + observed = true + default: + observed = ch == nil + } + return nil + }) + + err := r.Run(context.Background()) + require.NoError(t, err) + assert.True(t, observed, "Stopping(ctx) should be nil without WithDrain") + }) +} diff --git a/with_recoverer_test.go b/with_recoverer_test.go index e7cb5dd..8f656b2 100644 --- a/with_recoverer_test.go +++ b/with_recoverer_test.go @@ -14,7 +14,7 @@ type InMemoryReporter struct { } func (i *InMemoryReporter) Report(ctx context.Context, rec interface{}) { - i.logs = append(i.logs, fmt.Sprintf("%s", rec.(string))) + i.logs = append(i.logs, fmt.Sprintf("%v", rec)) } func TestWithRecoverer(t *testing.T) { @@ -25,7 +25,6 @@ func TestWithRecoverer(t *testing.T) { fn := func(ctx context.Context) error { defer func() { counter++ }() panic("something went wrong") - return nil } r := New(fn, WithRecoverer(&reporter, nil)) @@ -42,7 +41,6 @@ func TestWithRecoverer(t *testing.T) { r := New(func(ctx context.Context) error { started <- struct{}{} panic("something went wrong") - return nil }, WithRecoverer(reporter, nil)) go func() { @@ -82,7 +80,6 @@ func TestWithRecoverer(t *testing.T) { r := New(func(ctx context.Context) error { started <- struct{}{} panic("something went wrong") - return nil }, WithRecoverer(reporter, nil), WithStatus("test", store)) go func() { diff --git a/with_retry.go b/with_retry.go index 97e0bbd..ee67e70 100644 --- a/with_retry.go +++ b/with_retry.go @@ -11,8 +11,6 @@ const ResetNever time.Duration = 0 type withRetry struct { maxRetries int resetAfter time.Duration - - lastTime time.Time } func WithRetry(maxRetries int, resetAfter time.Duration) Option { @@ -25,12 +23,15 @@ func WithRetry(maxRetries int, resetAfter time.Duration) Option { func (w *withRetry) apply(r *runnable) { runFunc := r.runFunc r.runFunc = func(ctx context.Context) error { + // lastTime is per-Run-cycle: a fresh Run after Stop should not + // inherit stale timing state from the prior cycle. + var lastTime time.Time var err error for i := 0; i < w.maxRetries; i++ { - if w.resetAfter != ResetNever && time.Since(w.lastTime) > w.resetAfter { + if w.resetAfter != ResetNever && time.Since(lastTime) > w.resetAfter { i = 0 } - w.lastTime = time.Now() + lastTime = time.Now() if i > 0 { if r.onStart != nil { @@ -46,6 +47,17 @@ func (w *withRetry) apply(r *runnable) { return err } + // Don't retry once Stop has been called via WithDrain — + // the retry wrapper would otherwise re-enter runFunc and + // start fresh work mid-shutdown, defeating drain semantics. 
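+		// A closed Stopping channel always wins this select, so no new
+		// attempt can begin once the drain signal has fired.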
+ // When WithDrain is not used, Stopping(ctx) is nil and the + // default branch runs (no behavior change). + select { + case <-Stopping(ctx): + return err + default: + } + if i > 0 { if r.onStop != nil { r.onStop() diff --git a/with_retry_test.go b/with_retry_test.go index 6166473..c4781d5 100644 --- a/with_retry_test.go +++ b/with_retry_test.go @@ -58,4 +58,5 @@ func TestWithRetry(t *testing.T) { require.NoError(t, err) assert.Equal(t, 6, counter) }) + }