diff --git a/pyproject.toml b/pyproject.toml index 2563319..fb6fa6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "durable-workflow" -version = "0.4.34" +version = "0.4.35" description = "Python SDK for the Durable Workflow server (language-neutral HTTP protocol)" readme = "README.md" requires-python = ">=3.10" diff --git a/src/durable_workflow/client.py b/src/durable_workflow/client.py index 642d05d..56e0222 100644 --- a/src/durable_workflow/client.py +++ b/src/durable_workflow/client.py @@ -3276,14 +3276,22 @@ async def poll_activity_task( Returns the task payload, or ``None`` on poll timeout. Worker-plane endpoint — typically used by :class:`~durable_workflow.Worker`. """ - body: dict[str, Any] = {"worker_id": worker_id, "task_queue": task_queue} - try: - data = await self._request( - "POST", "/worker/activity-tasks/poll", worker=True, json=body, timeout=timeout - ) - except httpx.TimeoutException: - return None - return (data or {}).get("task") + body: dict[str, Any] = { + "worker_id": worker_id, + "task_queue": task_queue, + "poll_request_id": f"activity-poll-{uuid.uuid4().hex}", + } + for _ in range(2): + try: + data = await self._request( + "POST", "/worker/activity-tasks/poll", worker=True, json=body, timeout=timeout + ) + except httpx.TimeoutException: + continue + + return (data or {}).get("task") + + return None async def complete_activity_task( self, diff --git a/tests/test_client.py b/tests/test_client.py index 0f19e03..5f66315 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2072,6 +2072,43 @@ async def test_poll_workflow_task_retries_once_with_same_poll_request_id_after_t assert first_body["task_queue"] == "queue-1" assert first_body["poll_request_id"] == second_body["poll_request_id"] + @pytest.mark.asyncio + async def test_poll_activity_task_sends_poll_request_id(self, client: Client) -> None: + response_task = {"task": {"task_id": "activity-task-123"}} + + with patch.object(client, "_request", new_callable=AsyncMock, return_value=response_task) as mock: + task = await client.poll_activity_task(worker_id="worker-1", task_queue="queue-1") + + assert task == response_task["task"] + request_body = mock.await_args.kwargs["json"] + assert request_body["worker_id"] == "worker-1" + assert request_body["task_queue"] == "queue-1" + assert isinstance(request_body["poll_request_id"], str) + assert request_body["poll_request_id"] != "" + + @pytest.mark.asyncio + async def test_poll_activity_task_retries_once_with_same_poll_request_id_after_timeout( + self, client: Client + ) -> None: + response_task = {"task": {"task_id": "activity-task-123"}} + + with patch.object( + client, + "_request", + new_callable=AsyncMock, + side_effect=[httpx.TimeoutException("timeout"), response_task], + ) as mock: + task = await client.poll_activity_task(worker_id="worker-1", task_queue="queue-1") + + assert task == response_task["task"] + assert mock.await_count == 2 + + first_body = mock.await_args_list[0].kwargs["json"] + second_body = mock.await_args_list[1].kwargs["json"] + assert first_body["worker_id"] == "worker-1" + assert first_body["task_queue"] == "queue-1" + assert first_body["poll_request_id"] == second_body["poll_request_id"] + @pytest.mark.asyncio async def test_complete_workflow_task_matches_polyglot_fixture(self, client: Client) -> None: fixture_path = Path(__file__).parent / "fixtures" / "control-plane" / "workflow-task-complete-parity.json"