diff --git a/github/transport.go b/github/transport.go index d397b02503..b0ee83a915 100644 --- a/github/transport.go +++ b/github/transport.go @@ -283,7 +283,9 @@ func (t *RetryTransport) RoundTrip(req *http.Request) (*http.Response, error) { return resp, err } - time.Sleep(t.retryDelay) + if retry < t.maxRetries { + sleep(req.Context(), t.retryDelay) + } } return resp, err diff --git a/github/transport_test.go b/github/transport_test.go index eb2b07ba28..442dadee53 100644 --- a/github/transport_test.go +++ b/github/transport_test.go @@ -502,6 +502,71 @@ func TestRetryTransport_retry_post_success(t *testing.T) { } } +func TestRetryTransport_cancelled(t *testing.T) { + ts := githubApiMock([]*mockResponse{ + { + ExpectedUri: "/repos/test/blah", + ResponseBody: `{ + "message": "internal server error" +}`, + StatusCode: 500, + }, + }) + defer ts.Close() + httpClient := http.DefaultClient + httpClient.Transport = NewRetryTransport(http.DefaultTransport, + WithMaxRetries(1), + WithRetryDelay(10*time.Second), + ) + client := github.NewClient(httpClient) + u, _ := url.Parse(ts.URL + "/") + client.BaseURL = u + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + start := time.Now() + _, _, err := client.Repositories.Get(ctx, "test", "blah") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Expected context.DeadlineExceeded, got: %v", err) + } + if time.Since(start) > time.Second { + t.Fatalf("Waited longer than expected: %s", time.Since(start)) + } +} + +func TestRetryTransport_no_sleep_after_last_retry(t *testing.T) { + ts := githubApiMock([]*mockResponse{ + { + ExpectedUri: "/repos/test/blah", + ResponseBody: `{ + "message": "internal server error" +}`, + StatusCode: 500, + }, + { + ExpectedUri: "/repos/test/blah", + ResponseBody: `{ + "message": "internal server error" +}`, + StatusCode: 500, + }, + }) + defer ts.Close() + httpClient := http.DefaultClient + httpClient.Transport = NewRetryTransport(http.DefaultTransport, + WithMaxRetries(1), + WithRetryDelay(10*time.Second), + ) + client := github.NewClient(httpClient) + u, _ := url.Parse(ts.URL + "/") + client.BaseURL = u + start := time.Now() + ctx := context.WithValue(context.Background(), ctxId, t.Name()) + _, _, _ = client.Repositories.Get(ctx, "test", "blah") + if time.Since(start) > time.Second { + t.Fatalf("Slept after last retry: %s", time.Since(start)) + } +} + type mockResponse struct { ExpectedUri string ExpectedMethod string