From 02f338a5bdf7fa846ec8a42988bb335c75099678 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Fri, 28 Nov 2025 11:06:48 +0100 Subject: [PATCH 1/3] chore: enable ruff for taskiq/compat.py --- pyproject.toml | 3 +++ taskiq/compat.py | 36 ++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9cbae573..b923c135 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,6 +177,9 @@ lint.mccabe = { max-complexity = 10 } line-length = 88 [tool.ruff.lint.per-file-ignores] +"taskiq/compat.py" = [ + "D103", # Missing docstring in public function +] "tests/*" = [ "S101", # Use of assert detected "S301", # Use of pickle detected diff --git a/taskiq/compat.py b/taskiq/compat.py index f221f67d..5ff3d0ee 100644 --- a/taskiq/compat.py +++ b/taskiq/compat.py @@ -1,7 +1,7 @@ -# flake8: noqa +from collections.abc import Hashable from functools import lru_cache from importlib.metadata import version -from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union +from typing import Any, TypeVar import pydantic from packaging.version import Version, parse @@ -9,30 +9,30 @@ PYDANTIC_VER = parse(version("pydantic")) Model = TypeVar("Model", bound="pydantic.BaseModel") -IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0") +IS_PYDANTIC2 = Version("2.0") <= PYDANTIC_VER if IS_PYDANTIC2: T = TypeVar("T", bound=Hashable) - @lru_cache() - def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]: + @lru_cache + def create_type_adapter(annot: type[T]) -> pydantic.TypeAdapter[T]: return pydantic.TypeAdapter(annot) - def parse_obj_as(annot: Type[T], obj: Any) -> T: + def parse_obj_as(annot: type[T], obj: Any) -> T: return create_type_adapter(annot).validate_python(obj) def model_validate( - model_class: Type[Model], - message: Dict[str, Any], + model_class: type[Model], + message: dict[str, Any], ) -> Model: return model_class.model_validate(message) - def model_dump(instance: Model) -> Dict[str, Any]: + def model_dump(instance: Model) -> dict[str, Any]: return instance.model_dump(mode="json") def model_validate_json( - model_class: Type[Model], - message: Union[str, bytes, bytearray], + model_class: type[Model], + message: str | bytes | bytearray, ) -> Model: return model_class.model_validate_json(message) @@ -41,7 +41,7 @@ def model_dump_json(instance: Model) -> str: def model_copy( instance: Model, - update: Optional[Dict[str, Any]] = None, + update: dict[str, Any] | None = None, deep: bool = False, ) -> Model: return instance.model_copy(update=update, deep=deep) @@ -52,17 +52,17 @@ def model_copy( parse_obj_as = pydantic.parse_obj_as # type: ignore def model_validate( - model_class: Type[Model], - message: Dict[str, Any], + model_class: type[Model], + message: dict[str, Any], ) -> Model: return model_class.parse_obj(message) - def model_dump(instance: Model) -> Dict[str, Any]: + def model_dump(instance: Model) -> dict[str, Any]: return instance.dict() def model_validate_json( - model_class: Type[Model], - message: Union[str, bytes, bytearray], + model_class: type[Model], + message: str | bytes | bytearray, ) -> Model: return model_class.parse_raw(message) # type: ignore[arg-type] @@ -71,7 +71,7 @@ def model_dump_json(instance: Model) -> str: def model_copy( instance: Model, - update: Optional[Dict[str, Any]] = None, + update: dict[str, Any] | None = None, deep: bool = False, ) -> Model: return instance.copy(update=update, deep=deep) From 253624d00d63784e8d052bb7c091f8d6bd3e8839 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Fri, 28 Nov 2025 16:55:02 +0100 Subject: [PATCH 2/3] fix: TaskiqAdminMiddleware work with dataclasses --- pyproject.toml | 3 +- taskiq/middlewares/taskiq_admin_middleware.py | 18 ++-- .../middlewares/admin_middleware/conftest.py | 72 +++++++++++++ tests/middlewares/admin_middleware/dto.py | 36 +++++++ .../test_arguments_formatting.py | 52 +++++++++ .../test_taskiq_admin_middleware.py | 101 ------------------ uv.lock | 31 ++++++ 7 files changed, 204 insertions(+), 109 deletions(-) create mode 100644 tests/middlewares/admin_middleware/conftest.py create mode 100644 tests/middlewares/admin_middleware/dto.py create mode 100644 tests/middlewares/admin_middleware/test_arguments_formatting.py delete mode 100644 tests/middlewares/test_taskiq_admin_middleware.py diff --git a/pyproject.toml b/pyproject.toml index b923c135..47bd43da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dev = [ "freezegun>=1.5.5", "tzdata>=2025.2; sys_platform == 'win32'", "opentelemetry-test-utils (>=0.59b0,<1)", + "polyfactory>=3.1.0", ] [project.urls] @@ -172,8 +173,8 @@ lint.ignore = [ "PLR0913", # Too many arguments for function call "D106", # Missing docstring in public nested class ] -exclude = [".venv/"] lint.mccabe = { max-complexity = 10 } +exclude = [".venv/"] line-length = 88 [tool.ruff.lint.per-file-ignores] diff --git a/taskiq/middlewares/taskiq_admin_middleware.py b/taskiq/middlewares/taskiq_admin_middleware.py index 855b6eef..ad128924 100644 --- a/taskiq/middlewares/taskiq_admin_middleware.py +++ b/taskiq/middlewares/taskiq_admin_middleware.py @@ -7,6 +7,7 @@ import aiohttp from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.compat import model_dump from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult @@ -115,12 +116,13 @@ async def post_send(self, message: TaskiqMessage) -> None: :param message: kicked message. """ + dict_message: dict[str, Any] = model_dump(message) await self._spawn_request( f"/api/tasks/{message.task_id}/queued", { - "args": message.args, - "kwargs": message.kwargs, - "labels": message.labels, + "args": dict_message["args"], + "kwargs": dict_message["kwargs"], + "labels": dict_message["labels"], "queuedAt": self._now_iso(), "taskName": message.task_name, "worker": self.__ta_broker_name, @@ -137,12 +139,13 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: :param message: incoming parsed taskiq message. :return: modified message. """ + dict_message: dict[str, Any] = model_dump(message) await self._spawn_request( f"/api/tasks/{message.task_id}/started", { - "args": message.args, - "kwargs": message.kwargs, - "labels": message.labels, + "args": dict_message["args"], + "kwargs": dict_message["kwargs"], + "labels": dict_message["labels"], "startedAt": self._now_iso(), "taskName": message.task_name, "worker": self.__ta_broker_name, @@ -164,12 +167,13 @@ async def post_execute( :param message: incoming message. :param result: result of execution for current task. """ + dict_result: dict[str, Any] = model_dump(result) await self._spawn_request( f"/api/tasks/{message.task_id}/executed", { "finishedAt": self._now_iso(), "executionTime": result.execution_time, "error": None if result.error is None else repr(result.error), - "returnValue": {"return_value": result.return_value}, + "returnValue": {"return_value": dict_result["return_value"]}, }, ) diff --git a/tests/middlewares/admin_middleware/conftest.py b/tests/middlewares/admin_middleware/conftest.py new file mode 100644 index 00000000..7288902d --- /dev/null +++ b/tests/middlewares/admin_middleware/conftest.py @@ -0,0 +1,72 @@ +import pytest +from aiohttp import web +from aiohttp.test_utils import TestServer +from typing_extensions import AsyncGenerator + +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.brokers.shared_broker import async_shared_broker +from taskiq.middlewares import TaskiqAdminMiddleware +from tests.middlewares.admin_middleware.dto import ( + DataclassDTO, + PydanticDTO, + TypedDictDTO, +) + + +@pytest.fixture(scope="session") +async def admin_api_server() -> AsyncGenerator[TestServer, None]: + async def handle_queued(request: web.Request) -> web.Response: + return web.json_response({"status": "ok"}, status=200) + + async def handle_started(request: web.Request) -> web.Response: + return web.json_response({"status": "ok"}, status=200) + + async def handle_executed(request: web.Request) -> web.Response: + return web.json_response({"status": "ok"}, status=200) + + app = web.Application() + app.router.add_post("/api/tasks/{task_id}/queued", handle_queued) + app.router.add_post("/api/tasks/{task_id}/started", handle_started) + app.router.add_post("/api/tasks/{task_id}/executed", handle_executed) + + server = TestServer(app) + await server.start_server() + + yield server + + # Останавливаем сервер после теста + await server.close() + + +@pytest.fixture +async def broker_with_admin_middleware( + admin_api_server: TestServer, +) -> AsyncGenerator[InMemoryBroker, None]: + broker = InMemoryBroker().with_middlewares( + TaskiqAdminMiddleware( + str(admin_api_server.make_url("/")), # URL тестового сервера + "supersecret", + taskiq_broker_name="InMemory", + ), + ) + + broker.register_task(task_with_dataclass, task_name="task_with_dataclass") + broker.register_task(task_with_typed_dict, task_name="task_with_typed_dict") + broker.register_task(task_with_pydantic_model, task_name="task_with_pydantic_model") + async_shared_broker.default_broker(broker) + + await broker.startup() + yield broker + await broker.shutdown() + + +async def task_with_dataclass(dto: DataclassDTO) -> None: + assert dto + + +async def task_with_typed_dict(dto: TypedDictDTO) -> None: + assert dto + + +async def task_with_pydantic_model(dto: PydanticDTO) -> None: + assert dto diff --git a/tests/middlewares/admin_middleware/dto.py b/tests/middlewares/admin_middleware/dto.py new file mode 100644 index 00000000..03f99178 --- /dev/null +++ b/tests/middlewares/admin_middleware/dto.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass +from typing import TypedDict + +import pydantic + + +@dataclass(frozen=True, slots=True) +class DataclassNestedDTO: + id: int + name: str + + +@dataclass(frozen=True, slots=True) +class DataclassDTO: + nested: DataclassNestedDTO + recipients: list[str] + subject: str + attachments: list[str] | None = None + text: str | None = None + html: str | None = None + + +class PydanticDTO(pydantic.BaseModel): + number: int + text: str + flag: bool + list: list[float] + dictionary: dict[str, str] | None = None + + +class TypedDictDTO(TypedDict): + id: int + name: str + active: bool + scores: list[int] + metadata: dict[str, str] | None diff --git a/tests/middlewares/admin_middleware/test_arguments_formatting.py b/tests/middlewares/admin_middleware/test_arguments_formatting.py new file mode 100644 index 00000000..0d021625 --- /dev/null +++ b/tests/middlewares/admin_middleware/test_arguments_formatting.py @@ -0,0 +1,52 @@ +import pytest +from polyfactory.factories import BaseFactory, DataclassFactory, TypedDictFactory +from polyfactory.factories.pydantic_factory import ModelFactory + +from taskiq.brokers.inmemory_broker import InMemoryBroker +from tests.middlewares.admin_middleware.dto import ( + DataclassDTO, + PydanticDTO, + TypedDictDTO, +) + + +class DataclassDTOFactory(DataclassFactory[DataclassDTO]): + __model__ = DataclassDTO + + +class TypedDictDTOFactory(TypedDictFactory[TypedDictDTO]): + __model__ = TypedDictDTO + + +class PydanticDTOFactory(ModelFactory[PydanticDTO]): + __model__ = PydanticDTO + + +class TestArgumentsFormattingInAdminMiddleware: + @pytest.mark.parametrize( + "dto_factory, task_name", + [ + pytest.param(DataclassDTOFactory, "task_with_dataclass", id="dataclass"), + pytest.param(TypedDictDTOFactory, "task_with_typed_dict", id="typeddict"), + pytest.param(PydanticDTOFactory, "task_with_pydantic_model", id="pydantic"), + ], + ) + async def test_when_task_dto_passed__then_middleware_successfully_send_request( + self, + broker_with_admin_middleware: InMemoryBroker, + dto_factory: type[BaseFactory], # type: ignore[type-arg] + task_name: str, + ) -> None: + # given + task_arguments = dto_factory.build() + task = broker_with_admin_middleware.find_task(task_name) + assert task is not None, f"Task {task_name} should be registered in the broker" + + # when + kicked_task = await task.kiq(task_arguments) + await broker_with_admin_middleware.wait_all() + + # then + result = await kicked_task.get_result() + # we just expect no errors during post_send/pre_execute/post_execute + assert result.error is None diff --git a/tests/middlewares/test_taskiq_admin_middleware.py b/tests/middlewares/test_taskiq_admin_middleware.py deleted file mode 100644 index 21771f24..00000000 --- a/tests/middlewares/test_taskiq_admin_middleware.py +++ /dev/null @@ -1,101 +0,0 @@ -import asyncio -import datetime -from collections.abc import AsyncGenerator -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from taskiq import TaskiqMessage -from taskiq.middlewares.taskiq_admin_middleware import TaskiqAdminMiddleware - - -@pytest.fixture -async def middleware() -> AsyncGenerator[TaskiqAdminMiddleware, None]: - middleware = TaskiqAdminMiddleware( - url="http://localhost:8000", - api_token="test-token", # noqa: S106 - timeout=5, - taskiq_broker_name="test-broker", - ) - await middleware.startup() - yield middleware - await middleware.shutdown() - - -@pytest.fixture -def message() -> TaskiqMessage: - return TaskiqMessage( - task_id="task-123", - task_name="test_task", - labels={ - "schedule": { - "cron": "*/1 * * * *", - "cron_offset": datetime.timedelta(hours=1), - "time": datetime.datetime.now(datetime.timezone.utc), - "labels": { - "test_bool": True, - "test_int": 1, - "test_str": "str", - "test_bytes": b"bytes", - }, - }, - }, - args=[1, 2, 3], - kwargs={"key": "value"}, - ) - - -def _make_mock_response() -> AsyncMock: - """Create a properly configured mock response object.""" - mock_response = AsyncMock() - mock_response.__aenter__.return_value = mock_response - mock_response.__aexit__.return_value = None - mock_response.ok = True - mock_response.raise_for_status = Mock() - return mock_response - - -class TestTaskiqAdminMiddlewarePostSend: - async def test_when_post_send_is_called__then_queued_endpoint_is_called( - self, - middleware: TaskiqAdminMiddleware, - message: TaskiqMessage, - ) -> None: - # Given - with patch("aiohttp.ClientSession.post") as mock_post: - mock_response = _make_mock_response() - mock_post.return_value = mock_response - - # When - await middleware.post_send(message) - await asyncio.sleep(0.1) - - # Then - mock_post.assert_called() - assert mock_post.call_args is not None - assert "/api/tasks/task-123/queued" in mock_post.call_args[0][0] - - async def test_when_post_send_is_called__then_payload_includes_task_info( - self, - middleware: TaskiqAdminMiddleware, - message: TaskiqMessage, - ) -> None: - # Given - with patch("aiohttp.ClientSession.post") as mock_post: - mock_response = _make_mock_response() - mock_post.return_value = mock_response - - # When - await middleware.post_send(message) - await asyncio.sleep(0.1) - - # Then - call_args = mock_post.call_args - assert call_args is not None - payload = call_args[1]["json"] - assert payload["args"] == message.args - assert payload["kwargs"] == message.kwargs - assert payload["taskName"] == message.task_name - assert payload["worker"] == "test-broker" - assert payload["labels"] == message.labels - assert "queuedAt" in payload diff --git a/uv.lock b/uv.lock index 1b2969fc..2828eb5b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,10 @@ version = 1 revision = 3 requires-python = ">=3.10, <4" +resolution-markers = [ + "platform_python_implementation != 'PyPy'", + "platform_python_implementation == 'PyPy'", +] [[package]] name = "aiohappyeyeballs" @@ -550,6 +554,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, ] +[[package]] +name = "faker" +version = "38.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/27/022d4dbd4c20567b4c294f79a133cc2f05240ea61e0d515ead18c995c249/faker-38.2.0.tar.gz", hash = "sha256:20672803db9c7cb97f9b56c18c54b915b6f1d8991f63d1d673642dc43f5ce7ab", size = 1941469, upload-time = "2025-11-19T16:37:31.892Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/93/00c94d45f55c336434a15f98d906387e87ce28f9918e4444829a8fda432d/faker-38.2.0-py3-none-any.whl", hash = "sha256:35fe4a0a79dee0dc4103a6083ee9224941e7d3594811a50e3969e547b0d2ee65", size = 1980505, upload-time = "2025-11-19T16:37:30.208Z" }, +] + [[package]] name = "filelock" version = "3.20.0" @@ -1194,6 +1210,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "polyfactory" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "faker" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/3a/db522ea17e0e8d38f3128889b5b600b3a1d5728ae0724f43a0ed5ed1f82e/polyfactory-3.1.0.tar.gz", hash = "sha256:9061c0a282e0594502576455230fce534f2915042be77715256c1e6bbbf24ac5", size = 344189, upload-time = "2025-11-25T08:10:16.555Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/7c/535646d75a1c510065169ea65693613c7a6bc64491bea13e7dad4f028ff3/polyfactory-3.1.0-py3-none-any.whl", hash = "sha256:78171232342c25906d542513c9f00ebf41eadec2c67b498490a577024dd7e867", size = 61836, upload-time = "2025-11-25T08:10:14.893Z" }, +] + [[package]] name = "pre-commit" version = "4.4.0" @@ -1835,6 +1864,7 @@ dev = [ { name = "freezegun" }, { name = "mypy" }, { name = "opentelemetry-test-utils" }, + { name = "polyfactory" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -1875,6 +1905,7 @@ dev = [ { name = "freezegun", specifier = ">=1.5.5" }, { name = "mypy", specifier = ">=1.18.2" }, { name = "opentelemetry-test-utils", specifier = ">=0.59b0,<1" }, + { name = "polyfactory", specifier = ">=3.1.0" }, { name = "pre-commit", specifier = ">=4.4.0" }, { name = "pytest", specifier = ">=9.0.1" }, { name = "pytest-cov", specifier = ">=7.0.0" }, From 8c053fe261bceece690f23352f4c92dc94aeee33 Mon Sep 17 00:00:00 2001 From: Anfimov Dima Date: Fri, 28 Nov 2025 17:09:09 +0100 Subject: [PATCH 3/3] fix: try to use TypedDict from typing_extension (because of pydantic issue) --- tests/middlewares/admin_middleware/conftest.py | 7 ++----- tests/middlewares/admin_middleware/dto.py | 2 +- .../admin_middleware/test_arguments_formatting.py | 1 + 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/middlewares/admin_middleware/conftest.py b/tests/middlewares/admin_middleware/conftest.py index 7288902d..dbec77a4 100644 --- a/tests/middlewares/admin_middleware/conftest.py +++ b/tests/middlewares/admin_middleware/conftest.py @@ -13,7 +13,7 @@ ) -@pytest.fixture(scope="session") +@pytest.fixture async def admin_api_server() -> AsyncGenerator[TestServer, None]: async def handle_queued(request: web.Request) -> web.Response: return web.json_response({"status": "ok"}, status=200) @@ -31,10 +31,7 @@ async def handle_executed(request: web.Request) -> web.Response: server = TestServer(app) await server.start_server() - yield server - - # Останавливаем сервер после теста await server.close() @@ -42,7 +39,7 @@ async def handle_executed(request: web.Request) -> web.Response: async def broker_with_admin_middleware( admin_api_server: TestServer, ) -> AsyncGenerator[InMemoryBroker, None]: - broker = InMemoryBroker().with_middlewares( + broker = InMemoryBroker(await_inplace=True).with_middlewares( TaskiqAdminMiddleware( str(admin_api_server.make_url("/")), # URL тестового сервера "supersecret", diff --git a/tests/middlewares/admin_middleware/dto.py b/tests/middlewares/admin_middleware/dto.py index 03f99178..20ce9170 100644 --- a/tests/middlewares/admin_middleware/dto.py +++ b/tests/middlewares/admin_middleware/dto.py @@ -1,7 +1,7 @@ from dataclasses import dataclass -from typing import TypedDict import pydantic +from typing_extensions import TypedDict @dataclass(frozen=True, slots=True) diff --git a/tests/middlewares/admin_middleware/test_arguments_formatting.py b/tests/middlewares/admin_middleware/test_arguments_formatting.py index 0d021625..a65b5669 100644 --- a/tests/middlewares/admin_middleware/test_arguments_formatting.py +++ b/tests/middlewares/admin_middleware/test_arguments_formatting.py @@ -22,6 +22,7 @@ class PydanticDTOFactory(ModelFactory[PydanticDTO]): __model__ = PydanticDTO +# @pytest.mark.skip class TestArgumentsFormattingInAdminMiddleware: @pytest.mark.parametrize( "dto_factory, task_name",