From 9a3cde96635014d4e8ca86aa242ca57696612868 Mon Sep 17 00:00:00 2001 From: Durable Workflow Date: Fri, 15 May 2026 01:05:39 +0000 Subject: [PATCH] Retry workflow polls with stable request ids --- src/durable_workflow/client.py | 25 +++++++++++++++++-------- tests/test_client.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/durable_workflow/client.py b/src/durable_workflow/client.py index 873411c..642d05d 100644 --- a/src/durable_workflow/client.py +++ b/src/durable_workflow/client.py @@ -20,6 +20,7 @@ import asyncio import time +import uuid import warnings from collections.abc import AsyncIterator, Callable from dataclasses import dataclass @@ -3091,14 +3092,22 @@ async def poll_workflow_task( endpoint — most applications use :class:`~durable_workflow.Worker` rather than calling this directly. """ - body: dict[str, Any] = {"worker_id": worker_id, "task_queue": task_queue} - try: - data = await self._request( - "POST", "/worker/workflow-tasks/poll", worker=True, json=body, timeout=timeout - ) - except httpx.TimeoutException: - return None - return (data or {}).get("task") + body: dict[str, Any] = { + "worker_id": worker_id, + "task_queue": task_queue, + "poll_request_id": f"wf-poll-{uuid.uuid4().hex}", + } + for _ in range(2): + try: + data = await self._request( + "POST", "/worker/workflow-tasks/poll", worker=True, json=body, timeout=timeout + ) + except httpx.TimeoutException: + continue + + return (data or {}).get("task") + + return None async def complete_workflow_task( self, diff --git a/tests/test_client.py b/tests/test_client.py index fa721af..0f19e03 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2043,7 +2043,34 @@ async def test_poll_workflow_task_matches_polyglot_fixture(self, client: Client) assert task == fixture["response_body"]["task"] assert task["task_id"] == fixture["semantic_body"]["task_id"] assert mock.call_args.args[:2] == (fixture["request"]["method"], f"/api{fixture['request']['path']}") - assert mock.call_args.kwargs["json"] == fixture["request"]["body"] + request_body = mock.call_args.kwargs["json"] + assert request_body["worker_id"] == fixture["request"]["body"]["worker_id"] + assert request_body["task_queue"] == fixture["request"]["body"]["task_queue"] + assert isinstance(request_body["poll_request_id"], str) + assert request_body["poll_request_id"] != "" + + @pytest.mark.asyncio + async def test_poll_workflow_task_retries_once_with_same_poll_request_id_after_timeout( + self, client: Client + ) -> None: + response_task = {"task": {"task_id": "task-123"}} + + with patch.object( + client, + "_request", + new_callable=AsyncMock, + side_effect=[httpx.TimeoutException("timeout"), response_task], + ) as mock: + task = await client.poll_workflow_task(worker_id="worker-1", task_queue="queue-1") + + assert task == response_task["task"] + assert mock.await_count == 2 + + first_body = mock.await_args_list[0].kwargs["json"] + second_body = mock.await_args_list[1].kwargs["json"] + assert first_body["worker_id"] == "worker-1" + assert first_body["task_queue"] == "queue-1" + assert first_body["poll_request_id"] == second_body["poll_request_id"] @pytest.mark.asyncio async def test_complete_workflow_task_matches_polyglot_fixture(self, client: Client) -> None: