diff --git a/.github/workflows/java-ci.yml b/.github/workflows/java-ci.yml index 50d7741..5e2b535 100644 --- a/.github/workflows/java-ci.yml +++ b/.github/workflows/java-ci.yml @@ -33,8 +33,8 @@ jobs: --file duo-client/pom.xml - name: Test with Maven run: > - mvn test - --batch-mode + mvn verify + --batch-mode -file duo-client/pom.xml - name: Lint with checkstyle run: mvn checkstyle:check diff --git a/duo-client/pom.xml b/duo-client/pom.xml index c1a2c29..7957502 100644 --- a/duo-client/pom.xml +++ b/duo-client/pom.xml @@ -65,6 +65,18 @@ 3.12.4 test + + com.squareup.okhttp3 + mockwebserver + 4.12.0 + test + + + com.squareup.okhttp3 + okhttp-tls + 4.12.0 + test + @@ -123,6 +135,19 @@ methods + + org.apache.maven.plugins + maven-failsafe-plugin + 3.2.5 + + + + integration-test + verify + + + + org.cyclonedx cyclonedx-maven-plugin diff --git a/duo-client/src/main/java/com/duosecurity/client/Http.java b/duo-client/src/main/java/com/duosecurity/client/Http.java index fadb3fa..d897478 100644 --- a/duo-client/src/main/java/com/duosecurity/client/Http.java +++ b/duo-client/src/main/java/com/duosecurity/client/Http.java @@ -42,6 +42,7 @@ public class Http { private Headers.Builder headers; private SortedMap params = new TreeMap(); protected int sigVersion = 5; + private long maxBackoffMs = MAX_BACKOFF_MS; private Random random = new Random(); private OkHttpClient httpClient; private SortedMap additionalDuoHeaders = new TreeMap(); @@ -314,10 +315,14 @@ private Response executeRequest(Request request) throws Exception { long backoffMs = INITIAL_BACKOFF_MS; while (true) { Response response = httpClient.newCall(request).execute(); - if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > MAX_BACKOFF_MS) { + if (response.code() != RATE_LIMIT_ERROR_CODE || backoffMs > maxBackoffMs) { return response; } + // Close the 429 response to release the connection back to the pool before retrying + if (response.body() != null) { + response.close(); + } sleep(backoffMs + nextRandomInt(1000)); backoffMs *= BACKOFF_FACTOR; } @@ -327,6 +332,13 @@ protected void sleep(long ms) throws Exception { Thread.sleep(ms); } + protected void setMaxBackoffMs(long maxBackoffMs) { + if (maxBackoffMs < 0) { + throw new IllegalArgumentException("maxBackoffMs must be >= 0"); + } + this.maxBackoffMs = maxBackoffMs; + } + public void signRequest(String ikey, String skey) throws UnsupportedEncodingException { signRequest(ikey, skey, sigVersion); @@ -529,6 +541,7 @@ protected abstract static class ClientBuilder { private final String uri; private int timeout = DEFAULT_TIMEOUT_SECS; + private long maxBackoffMs = MAX_BACKOFF_MS; private String[] caCerts = null; private SortedMap additionalDuoHeaders = new TreeMap(); private Map headers = new HashMap(); @@ -558,6 +571,32 @@ public ClientBuilder useTimeout(int timeout) { return this; } + /** + * Set the maximum base backoff time in milliseconds for rate limit (429) retries. + * When a request receives a 429 response, the client retries with exponential + * backoff until the base backoff exceeds this threshold. Note that actual sleep + * time includes up to 1000ms of random jitter on top of the base backoff. + * Setting to 0 disables retries (as does any value below the initial + * backoff of 1000ms). Default is 32000ms (32 seconds). + * + *

Note: When using method chaining from outside this package (e.g. with + * {@code AuthBuilder} or {@code AdminBuilder}), assign the builder to a variable + * and call methods separately, then call {@code build()}. This is a known + * limitation of all {@code ClientBuilder} methods. + * + * @param maxBackoffMs the maximum base backoff in milliseconds (must be >= 0) + * @return the Builder + * @throws IllegalArgumentException if maxBackoffMs is negative + */ + public ClientBuilder useMaxBackoffMs(long maxBackoffMs) { + if (maxBackoffMs < 0) { + throw new IllegalArgumentException("maxBackoffMs must be >= 0"); + } + this.maxBackoffMs = maxBackoffMs; + + return this; + } + /** * Provide custom CA certificates for certificate pinning. * @@ -604,6 +643,7 @@ public ClientBuilder addHeader(String name, String value) { */ public T build() { T duoClient = createClient(method, host, uri, timeout); + duoClient.setMaxBackoffMs(maxBackoffMs); if (caCerts != null) { duoClient.useCustomCertificates(caCerts); } diff --git a/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryIntegrationIT.java b/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryIntegrationIT.java new file mode 100644 index 0000000..a127b71 --- /dev/null +++ b/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryIntegrationIT.java @@ -0,0 +1,138 @@ +package com.duosecurity.client; + +import okhttp3.OkHttpClient; +import okhttp3.Response; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.tls.HandshakeCertificates; +import okhttp3.tls.HeldCertificate; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.lang.reflect.Field; + +import static org.junit.Assert.assertEquals; + +public class HttpRateLimitRetryIntegrationIT { + + private MockWebServer server; + private HandshakeCertificates clientCerts; + + @Before + public void setUp() throws Exception { + HeldCertificate serverCert = new HeldCertificate.Builder() + .addSubjectAlternativeName("localhost") + .build(); + + HandshakeCertificates serverCerts = new HandshakeCertificates.Builder() + .heldCertificate(serverCert) + .build(); + + clientCerts = new HandshakeCertificates.Builder() + .addTrustedCertificate(serverCert.certificate()) + .build(); + + server = new MockWebServer(); + server.useHttps(serverCerts.sslSocketFactory(), false); + server.start(); + } + + @After + public void tearDown() throws Exception { + server.shutdown(); + } + + /** + * Builds an Http spy pointing at the MockWebServer, with sleep() stubbed out to avoid real + * delays and the OkHttpClient replaced with one that trusts the test certificate. + * + *

The builder must be constructed with host "localhost" (no port) so that CertificatePinner + * accepts the pattern. This method then sets the real host (with port) and replaces the + * OkHttpClient via reflection before the spy is used. + */ + private Http buildSpyHttp(Http.ClientBuilder builder) throws Exception { + Http spy = Mockito.spy(builder.build()); + Mockito.doNothing().when(spy).sleep(Mockito.any(Long.class)); + + // Point the host at the MockWebServer port (CertificatePinner rejects host:port patterns, + // so the builder uses "localhost" and we fix it here after construction). + Field hostField = Http.class.getDeclaredField("host"); + hostField.setAccessible(true); + hostField.set(spy, "localhost:" + server.getPort()); + + // Replace the OkHttpClient with one configured to trust the test certificate + OkHttpClient testClient = new OkHttpClient.Builder() + .sslSocketFactory(clientCerts.sslSocketFactory(), clientCerts.trustManager()) + .build(); + + Field httpClientField = Http.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + httpClientField.set(spy, testClient); + + return spy; + } + + private Http.HttpBuilder defaultBuilder() { + // Use "localhost" without a port — CertificatePinner rejects host:port patterns. + // buildSpyHttp sets the real host (with port) via reflection after construction. + return new Http.HttpBuilder("GET", "localhost", "/foo/bar"); + } + + @Test + public void testSingleRateLimitRetry() throws Exception { + server.enqueue(new MockResponse().setResponseCode(429)); + server.enqueue(new MockResponse().setResponseCode(200)); + + Http http = buildSpyHttp(defaultBuilder()); + Response response = http.executeHttpRequest(); + + assertEquals(200, response.code()); + assertEquals(2, server.getRequestCount()); + Mockito.verify(http, Mockito.times(1)).sleep(Mockito.any(Long.class)); + } + + @Test + public void testRateLimitExhaustsDefaultMaxBackoff() throws Exception { + // Enqueue more responses than will ever be consumed + for (int i = 0; i < 10; i++) { + server.enqueue(new MockResponse().setResponseCode(429)); + } + + Http http = buildSpyHttp(defaultBuilder()); + Response response = http.executeHttpRequest(); + + assertEquals(429, response.code()); + // Default max backoff (32s): sleeps at 1s, 2s, 4s, 8s, 16s, 32s = 6 sleeps, 7 total requests + assertEquals(7, server.getRequestCount()); + Mockito.verify(http, Mockito.times(6)).sleep(Mockito.any(Long.class)); + } + + @Test + public void testCustomMaxBackoffLimitsRetries() throws Exception { + for (int i = 0; i < 10; i++) { + server.enqueue(new MockResponse().setResponseCode(429)); + } + + Http http = buildSpyHttp(defaultBuilder().useMaxBackoffMs(4000)); + Response response = http.executeHttpRequest(); + + assertEquals(429, response.code()); + // maxBackoff=4000: sleeps at 1s, 2s, 4s = 3 sleeps, 4 total requests (next would be 8s > 4s) + assertEquals(4, server.getRequestCount()); + Mockito.verify(http, Mockito.times(3)).sleep(Mockito.any(Long.class)); + } + + @Test + public void testMaxBackoffZeroDisablesRetry() throws Exception { + server.enqueue(new MockResponse().setResponseCode(429)); + + Http http = buildSpyHttp(defaultBuilder().useMaxBackoffMs(0)); + Response response = http.executeHttpRequest(); + + assertEquals(429, response.code()); + assertEquals(1, server.getRequestCount()); + Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class)); + } +} diff --git a/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryTest.java b/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryTest.java index 5955e2c..baf41fb 100644 --- a/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryTest.java +++ b/duo-client/src/test/java/com/duosecurity/client/HttpRateLimitRetryTest.java @@ -26,10 +26,8 @@ public class HttpRateLimitRetryTest { private final int RANDOM_INT = 234; - @Before - public void before() throws Exception { - http = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build(); - http = Mockito.spy(http); + private void setupHttp(Http client) throws Exception { + http = Mockito.spy(client); Field httpClientField = Http.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); @@ -39,6 +37,12 @@ public void before() throws Exception { Mockito.doNothing().when(http).sleep(Mockito.any(Long.class)); } + @Before + public void before() throws Exception { + Http client = new Http.HttpBuilder("GET", "example.test", "/foo/bar").build(); + setupHttp(client); + } + @Test public void testSingleRateLimitRetry() throws Exception { final List responses = new ArrayList(); @@ -128,4 +132,87 @@ public Call answer(InvocationOnMock invocationOnMock) throws Throwable { assertEquals(16000L + RANDOM_INT, (long) sleepTimes.get(4)); assertEquals(32000L + RANDOM_INT, (long) sleepTimes.get(5)); } + + @Test + public void testMaxBackoffZeroDisablesRetry() throws Exception { + Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar") + .useMaxBackoffMs(0) + .build(); + setupHttp(customHttp); + + final List responses = new ArrayList(); + + Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer() { + @Override + public Call answer(InvocationOnMock invocationOnMock) throws Throwable { + Call call = Mockito.mock(Call.class); + + Response resp = new Response.Builder() + .protocol(Protocol.HTTP_2) + .code(429) + .request((Request) invocationOnMock.getArguments()[0]) + .message("HTTP 429") + .build(); + responses.add(resp); + Mockito.when(call.execute()).thenReturn(resp); + + return call; + } + }); + + Response actualRes = http.executeHttpRequest(); + assertEquals(1, responses.size()); + assertEquals(429, actualRes.code()); + + // Verify no sleep was called + Mockito.verify(http, Mockito.never()).sleep(Mockito.any(Long.class)); + } + + @Test + public void testMaxBackoffCustomLimit() throws Exception { + Http customHttp = new Http.HttpBuilder("GET", "example.test", "/foo/bar") + .useMaxBackoffMs(4000) + .build(); + setupHttp(customHttp); + + final List responses = new ArrayList(); + + Mockito.when(httpClient.newCall(Mockito.any(Request.class))).thenAnswer(new Answer() { + @Override + public Call answer(InvocationOnMock invocationOnMock) throws Throwable { + Call call = Mockito.mock(Call.class); + + Response resp = new Response.Builder() + .protocol(Protocol.HTTP_2) + .code(429) + .request((Request) invocationOnMock.getArguments()[0]) + .message("HTTP 429") + .build(); + responses.add(resp); + Mockito.when(call.execute()).thenReturn(resp); + + return call; + } + }); + + // With maxBackoff=4000, retries at 1000, 2000, 4000, then 8000 > 4000 exits + // That's 4 total requests (1 initial + 3 retries) + Response actualRes = http.executeHttpRequest(); + assertEquals(4, responses.size()); + assertEquals(429, actualRes.code()); + + ArgumentCaptor sleepCapture = ArgumentCaptor.forClass(Long.class); + Mockito.verify(http, Mockito.times(3)).sleep(sleepCapture.capture()); + List sleepTimes = sleepCapture.getAllValues(); + assertEquals(1000L + RANDOM_INT, (long) sleepTimes.get(0)); + assertEquals(2000L + RANDOM_INT, (long) sleepTimes.get(1)); + assertEquals(4000L + RANDOM_INT, (long) sleepTimes.get(2)); + } + + @Test(expected = IllegalArgumentException.class) + public void testMaxBackoffNegativeThrows() { + new Http.HttpBuilder("GET", "example.test", "/foo/bar") + .useMaxBackoffMs(-1) + .build(); + } }