Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pkg/utils/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,28 @@ func (s *Strategy[R]) Do(ctx context.Context, lggr logger.Logger, fn func(ctx co
}

// Track the number of retries
var lastKnownErr error
for numRetries := int(s.Backoff.Attempt()); ; numRetries++ {
if s.MaxRetries > 0 {
if numRetries > int(s.MaxRetries) {
var empty R
return empty, fmt.Errorf("max retry attempts reached")
return empty, fmt.Errorf("max retry attempts reached {retryID=%s, numRetries=%d}, last known err: %w", retryID, numRetries, lastKnownErr)
}
}

result, err := fn(ctx)
if err == nil {
return result, nil
}
lastKnownErr = err

wait := s.Backoff.Duration()
message := fmt.Sprintf("Failed to execute function, retrying in %s ...", wait)
lggr.Warnw(message, "wait", wait, "numRetries", numRetries, "retryID", retryID, "err", err)

select {
case <-ctx.Done():
return result, fmt.Errorf("context done while executing function {retryID=%s, numRetries=%d}: %w", retryID, numRetries, ctx.Err())
return result, fmt.Errorf("context done while executing function {retryID=%s, numRetries=%d}, last known err: %w: %w", retryID, numRetries, lastKnownErr, ctx.Err())
case <-time.After(wait):
// Continue with the next retry
}
Expand Down
44 changes: 37 additions & 7 deletions pkg/utils/retry/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package retry
import (
"context"
"errors"
"fmt"
"testing"
"time"

Expand All @@ -17,7 +18,7 @@ type testCase struct {
name string
fn exampleFunc
expected string
errMsg string
errMsg []string
timeout time.Duration
strategy *Strategy[string]
}
Expand All @@ -39,7 +40,7 @@ func TestWithRetry(t *testing.T) {
fn: func(ctx context.Context) (string, error) {
return "", errors.New("permanent error")
},
errMsg: "context done while executing function",
errMsg: []string{"context done while executing function", "permanent error"},
timeout: 100 * time.Millisecond,
},
{
Expand Down Expand Up @@ -69,7 +70,7 @@ func TestWithRetry(t *testing.T) {
return "eventual success", nil
}
}(),
errMsg: "context done while executing function",
errMsg: []string{"context done while executing function", "temporary error"},
timeout: 100 * time.Millisecond,
},
{
Expand Down Expand Up @@ -101,7 +102,7 @@ func TestWithRetry(t *testing.T) {
return "eventual success", nil
}
}(),
errMsg: "context done while executing function",
errMsg: []string{"context done while executing function", "temporary error"},
timeout: 1 * time.Second,
},
{
Expand Down Expand Up @@ -136,9 +137,36 @@ func TestWithRetry(t *testing.T) {
strategy: &Strategy[string]{
MaxRetries: 1,
},
errMsg: "max retry attempts reached",
errMsg: []string{"max retry attempts reached", "numRetries=2", "temporary error"},
timeout: 1 * time.Second,
},
{
name: "context timeout surfaces the last callback error, not earlier ones",
fn: func() exampleFunc {
attempt := 0
return func(ctx context.Context) (string, error) {
attempt++
return "", fmt.Errorf("error on attempt %d", attempt)
}
}(),
errMsg: []string{"context done while executing function", "error on attempt"},
timeout: 300 * time.Millisecond,
},
{
name: "max retries surfaces the last callback error, not earlier ones",
fn: func() exampleFunc {
attempt := 0
return func(ctx context.Context) (string, error) {
attempt++
return "", fmt.Errorf("error on attempt %d", attempt)
}
}(),
strategy: &Strategy[string]{
MaxRetries: 3,
},
errMsg: []string{"max retry attempts reached", "error on attempt 4"},
timeout: 5 * time.Second,
},
}

for _, tt := range tests {
Expand All @@ -157,9 +185,11 @@ func TestWithRetry(t *testing.T) {
result, err = tt.strategy.Do(ctx, lggr, tt.fn)
}

if tt.errMsg != "" {
if len(tt.errMsg) > 0 {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
for _, msg := range tt.errMsg {
require.Contains(t, err.Error(), msg)
}
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, result)
Expand Down
Loading