Skip to content

Commit 4cf3f47

Browse files
committed
Retry 5XX errors and rate limiting during polling
We default to an exponential backoff, but allow callers to override that as needed.
1 parent 9ca7685 commit 4cf3f47

File tree

2 files changed

+85
-4
lines changed

2 files changed

+85
-4
lines changed

lib/ReplicateClient.js

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ export default class ReplicateClient {
2525
async predict(
2626
version,
2727
input,
28-
{ onUpdate } = {},
29-
{ defaultPollingInterval = 500 } = {}
28+
{ onUpdate, onTemporaryError } = {},
29+
{
30+
defaultPollingInterval = 500,
31+
backoffFn = (errorCount) => Math.pow(2, errorCount) * 100,
32+
} = {}
3033
) {
3134
if (!version) {
3235
throw new ReplicateError("version is required");
@@ -41,14 +44,36 @@ export default class ReplicateClient {
4144
onUpdate && onUpdate(prediction);
4245

4346
let pollingInterval = defaultPollingInterval;
47+
let errorCount = 0;
4448

4549
while (!prediction.hasTerminalStatus()) {
4650
await sleep(pollingInterval);
4751
pollingInterval = defaultPollingInterval; // Reset to default each time.
4852

49-
prediction = await this.getPrediction(prediction.id);
53+
try {
54+
prediction = await this.getPrediction(prediction.id);
55+
56+
onUpdate && onUpdate(prediction);
57+
58+
errorCount = 0; // Reset because we've had a non-error response.
59+
} catch (err) {
60+
if (!err instanceof ReplicateResponseError) {
61+
throw err;
62+
}
5063

51-
onUpdate && onUpdate(prediction);
64+
if (
65+
!err.status ||
66+
(Math.floor(err.status / 100) !== 5 && err.status !== 429)
67+
) {
68+
throw err;
69+
}
70+
71+
errorCount += 1;
72+
73+
onTemporaryError && onTemporaryError(err);
74+
75+
pollingInterval = backoffFn(errorCount);
76+
}
5277
}
5378

5479
return prediction;

lib/ReplicateClient.test.js

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,62 @@ describe("predict()", () => {
150150
status: PredictionStatus.SUCCEEDED,
151151
});
152152
});
153+
154+
it("retries polling on error", async () => {
155+
const requestMockReturnValues = {
156+
"POST /v1/predictions": [
157+
() => ({
158+
id: "test-id",
159+
status: PredictionStatus.STARTING,
160+
}),
161+
],
162+
"GET /v1/predictions/test-id": [
163+
() => {
164+
throw new ReplicateResponseError(
165+
"test error",
166+
new Response("{}", {
167+
status: 500,
168+
statusText: "Internal Server Error",
169+
})
170+
);
171+
},
172+
() => {
173+
throw new ReplicateResponseError(
174+
"test error",
175+
new Response("{}", {
176+
status: 429,
177+
statusText: "Too Many Requests",
178+
})
179+
);
180+
},
181+
() => ({
182+
id: "test-id",
183+
status: PredictionStatus.SUCCEEDED,
184+
}),
185+
],
186+
};
187+
188+
jest
189+
.spyOn(client, "request")
190+
.mockImplementation((action) =>
191+
requestMockReturnValues[action].shift()()
192+
);
193+
const backoffFn = jest.fn(() => 0);
194+
195+
await client.predict(
196+
"test-version",
197+
{ text: "test text" },
198+
{},
199+
{ defaultPollingInterval: 0, backoffFn }
200+
);
201+
202+
expect(client.request).toHaveBeenCalledTimes(4);
203+
expect(client.request).toHaveLastReturnedWith({
204+
id: "test-id",
205+
status: PredictionStatus.SUCCEEDED,
206+
});
207+
expect(backoffFn).toHaveBeenCalledTimes(2);
208+
});
153209
});
154210

155211
describe("createPrediction()", () => {

0 commit comments

Comments
 (0)