Skip to content

Commit 87d5394

Browse files
authored
fix: TaskiqAdminMiddleware doesn't work with dataclasses (#554)
1 parent 69077ad commit 87d5394

File tree

8 files changed

+223
-127
lines changed

8 files changed

+223
-127
lines changed

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ dev = [
6868
"freezegun>=1.5.5",
6969
"tzdata>=2025.2; sys_platform == 'win32'",
7070
"opentelemetry-test-utils (>=0.59b0,<1)",
71+
"polyfactory>=3.1.0",
7172
]
7273

7374
[project.urls]
@@ -172,11 +173,14 @@ lint.ignore = [
172173
"PLR0913", # Too many arguments for function call
173174
"D106", # Missing docstring in public nested class
174175
]
175-
exclude = [".venv/"]
176176
lint.mccabe = { max-complexity = 10 }
177+
exclude = [".venv/"]
177178
line-length = 88
178179

179180
[tool.ruff.lint.per-file-ignores]
181+
"taskiq/compat.py" = [
182+
"D103", # Missing docstring in public function
183+
]
180184
"tests/*" = [
181185
"S101", # Use of assert detected
182186
"S301", # Use of pickle detected

taskiq/compat.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,38 @@
1-
# flake8: noqa
1+
from collections.abc import Hashable
22
from functools import lru_cache
33
from importlib.metadata import version
4-
from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union
4+
from typing import Any, TypeVar
55

66
import pydantic
77
from packaging.version import Version, parse
88

99
PYDANTIC_VER = parse(version("pydantic"))
1010

1111
Model = TypeVar("Model", bound="pydantic.BaseModel")
12-
IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0")
12+
IS_PYDANTIC2 = Version("2.0") <= PYDANTIC_VER
1313

1414
if IS_PYDANTIC2:
1515
T = TypeVar("T", bound=Hashable)
1616

17-
@lru_cache()
18-
def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]:
17+
@lru_cache
18+
def create_type_adapter(annot: type[T]) -> pydantic.TypeAdapter[T]:
1919
return pydantic.TypeAdapter(annot)
2020

21-
def parse_obj_as(annot: Type[T], obj: Any) -> T:
21+
def parse_obj_as(annot: type[T], obj: Any) -> T:
2222
return create_type_adapter(annot).validate_python(obj)
2323

2424
def model_validate(
25-
model_class: Type[Model],
26-
message: Dict[str, Any],
25+
model_class: type[Model],
26+
message: dict[str, Any],
2727
) -> Model:
2828
return model_class.model_validate(message)
2929

30-
def model_dump(instance: Model) -> Dict[str, Any]:
30+
def model_dump(instance: Model) -> dict[str, Any]:
3131
return instance.model_dump(mode="json")
3232

3333
def model_validate_json(
34-
model_class: Type[Model],
35-
message: Union[str, bytes, bytearray],
34+
model_class: type[Model],
35+
message: str | bytes | bytearray,
3636
) -> Model:
3737
return model_class.model_validate_json(message)
3838

@@ -41,7 +41,7 @@ def model_dump_json(instance: Model) -> str:
4141

4242
def model_copy(
4343
instance: Model,
44-
update: Optional[Dict[str, Any]] = None,
44+
update: dict[str, Any] | None = None,
4545
deep: bool = False,
4646
) -> Model:
4747
return instance.model_copy(update=update, deep=deep)
@@ -52,17 +52,17 @@ def model_copy(
5252
parse_obj_as = pydantic.parse_obj_as # type: ignore
5353

5454
def model_validate(
55-
model_class: Type[Model],
56-
message: Dict[str, Any],
55+
model_class: type[Model],
56+
message: dict[str, Any],
5757
) -> Model:
5858
return model_class.parse_obj(message)
5959

60-
def model_dump(instance: Model) -> Dict[str, Any]:
60+
def model_dump(instance: Model) -> dict[str, Any]:
6161
return instance.dict()
6262

6363
def model_validate_json(
64-
model_class: Type[Model],
65-
message: Union[str, bytes, bytearray],
64+
model_class: type[Model],
65+
message: str | bytes | bytearray,
6666
) -> Model:
6767
return model_class.parse_raw(message) # type: ignore[arg-type]
6868

@@ -71,7 +71,7 @@ def model_dump_json(instance: Model) -> str:
7171

7272
def model_copy(
7373
instance: Model,
74-
update: Optional[Dict[str, Any]] = None,
74+
update: dict[str, Any] | None = None,
7575
deep: bool = False,
7676
) -> Model:
7777
return instance.copy(update=update, deep=deep)

taskiq/middlewares/taskiq_admin_middleware.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import aiohttp
88

99
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.compat import model_dump
1011
from taskiq.message import TaskiqMessage
1112
from taskiq.result import TaskiqResult
1213

@@ -115,12 +116,13 @@ async def post_send(self, message: TaskiqMessage) -> None:
115116
116117
:param message: kicked message.
117118
"""
119+
dict_message: dict[str, Any] = model_dump(message)
118120
await self._spawn_request(
119121
f"/api/tasks/{message.task_id}/queued",
120122
{
121-
"args": message.args,
122-
"kwargs": message.kwargs,
123-
"labels": message.labels,
123+
"args": dict_message["args"],
124+
"kwargs": dict_message["kwargs"],
125+
"labels": dict_message["labels"],
124126
"queuedAt": self._now_iso(),
125127
"taskName": message.task_name,
126128
"worker": self.__ta_broker_name,
@@ -137,12 +139,13 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
137139
:param message: incoming parsed taskiq message.
138140
:return: modified message.
139141
"""
142+
dict_message: dict[str, Any] = model_dump(message)
140143
await self._spawn_request(
141144
f"/api/tasks/{message.task_id}/started",
142145
{
143-
"args": message.args,
144-
"kwargs": message.kwargs,
145-
"labels": message.labels,
146+
"args": dict_message["args"],
147+
"kwargs": dict_message["kwargs"],
148+
"labels": dict_message["labels"],
146149
"startedAt": self._now_iso(),
147150
"taskName": message.task_name,
148151
"worker": self.__ta_broker_name,
@@ -164,12 +167,13 @@ async def post_execute(
164167
:param message: incoming message.
165168
:param result: result of execution for current task.
166169
"""
170+
dict_result: dict[str, Any] = model_dump(result)
167171
await self._spawn_request(
168172
f"/api/tasks/{message.task_id}/executed",
169173
{
170174
"finishedAt": self._now_iso(),
171175
"executionTime": result.execution_time,
172176
"error": None if result.error is None else repr(result.error),
173-
"returnValue": {"return_value": result.return_value},
177+
"returnValue": {"return_value": dict_result["return_value"]},
174178
},
175179
)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
from aiohttp import web
3+
from aiohttp.test_utils import TestServer
4+
from typing_extensions import AsyncGenerator
5+
6+
from taskiq.brokers.inmemory_broker import InMemoryBroker
7+
from taskiq.brokers.shared_broker import async_shared_broker
8+
from taskiq.middlewares import TaskiqAdminMiddleware
9+
from tests.middlewares.admin_middleware.dto import (
10+
DataclassDTO,
11+
PydanticDTO,
12+
TypedDictDTO,
13+
)
14+
15+
16+
@pytest.fixture
17+
async def admin_api_server() -> AsyncGenerator[TestServer, None]:
18+
async def handle_queued(request: web.Request) -> web.Response:
19+
return web.json_response({"status": "ok"}, status=200)
20+
21+
async def handle_started(request: web.Request) -> web.Response:
22+
return web.json_response({"status": "ok"}, status=200)
23+
24+
async def handle_executed(request: web.Request) -> web.Response:
25+
return web.json_response({"status": "ok"}, status=200)
26+
27+
app = web.Application()
28+
app.router.add_post("/api/tasks/{task_id}/queued", handle_queued)
29+
app.router.add_post("/api/tasks/{task_id}/started", handle_started)
30+
app.router.add_post("/api/tasks/{task_id}/executed", handle_executed)
31+
32+
server = TestServer(app)
33+
await server.start_server()
34+
yield server
35+
await server.close()
36+
37+
38+
@pytest.fixture
39+
async def broker_with_admin_middleware(
40+
admin_api_server: TestServer,
41+
) -> AsyncGenerator[InMemoryBroker, None]:
42+
broker = InMemoryBroker(await_inplace=True).with_middlewares(
43+
TaskiqAdminMiddleware(
44+
str(admin_api_server.make_url("/")), # URL тестового сервера
45+
"supersecret",
46+
taskiq_broker_name="InMemory",
47+
),
48+
)
49+
50+
broker.register_task(task_with_dataclass, task_name="task_with_dataclass")
51+
broker.register_task(task_with_typed_dict, task_name="task_with_typed_dict")
52+
broker.register_task(task_with_pydantic_model, task_name="task_with_pydantic_model")
53+
async_shared_broker.default_broker(broker)
54+
55+
await broker.startup()
56+
yield broker
57+
await broker.shutdown()
58+
59+
60+
async def task_with_dataclass(dto: DataclassDTO) -> None:
61+
assert dto
62+
63+
64+
async def task_with_typed_dict(dto: TypedDictDTO) -> None:
65+
assert dto
66+
67+
68+
async def task_with_pydantic_model(dto: PydanticDTO) -> None:
69+
assert dto
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from dataclasses import dataclass
2+
3+
import pydantic
4+
from typing_extensions import TypedDict
5+
6+
7+
@dataclass(frozen=True, slots=True)
8+
class DataclassNestedDTO:
9+
id: int
10+
name: str
11+
12+
13+
@dataclass(frozen=True, slots=True)
14+
class DataclassDTO:
15+
nested: DataclassNestedDTO
16+
recipients: list[str]
17+
subject: str
18+
attachments: list[str] | None = None
19+
text: str | None = None
20+
html: str | None = None
21+
22+
23+
class PydanticDTO(pydantic.BaseModel):
24+
number: int
25+
text: str
26+
flag: bool
27+
list: list[float]
28+
dictionary: dict[str, str] | None = None
29+
30+
31+
class TypedDictDTO(TypedDict):
32+
id: int
33+
name: str
34+
active: bool
35+
scores: list[int]
36+
metadata: dict[str, str] | None
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
from polyfactory.factories import BaseFactory, DataclassFactory, TypedDictFactory
3+
from polyfactory.factories.pydantic_factory import ModelFactory
4+
5+
from taskiq.brokers.inmemory_broker import InMemoryBroker
6+
from tests.middlewares.admin_middleware.dto import (
7+
DataclassDTO,
8+
PydanticDTO,
9+
TypedDictDTO,
10+
)
11+
12+
13+
class DataclassDTOFactory(DataclassFactory[DataclassDTO]):
14+
__model__ = DataclassDTO
15+
16+
17+
class TypedDictDTOFactory(TypedDictFactory[TypedDictDTO]):
18+
__model__ = TypedDictDTO
19+
20+
21+
class PydanticDTOFactory(ModelFactory[PydanticDTO]):
22+
__model__ = PydanticDTO
23+
24+
25+
# @pytest.mark.skip
26+
class TestArgumentsFormattingInAdminMiddleware:
27+
@pytest.mark.parametrize(
28+
"dto_factory, task_name",
29+
[
30+
pytest.param(DataclassDTOFactory, "task_with_dataclass", id="dataclass"),
31+
pytest.param(TypedDictDTOFactory, "task_with_typed_dict", id="typeddict"),
32+
pytest.param(PydanticDTOFactory, "task_with_pydantic_model", id="pydantic"),
33+
],
34+
)
35+
async def test_when_task_dto_passed__then_middleware_successfully_send_request(
36+
self,
37+
broker_with_admin_middleware: InMemoryBroker,
38+
dto_factory: type[BaseFactory], # type: ignore[type-arg]
39+
task_name: str,
40+
) -> None:
41+
# given
42+
task_arguments = dto_factory.build()
43+
task = broker_with_admin_middleware.find_task(task_name)
44+
assert task is not None, f"Task {task_name} should be registered in the broker"
45+
46+
# when
47+
kicked_task = await task.kiq(task_arguments)
48+
await broker_with_admin_middleware.wait_all()
49+
50+
# then
51+
result = await kicked_task.get_result()
52+
# we just expect no errors during post_send/pre_execute/post_execute
53+
assert result.error is None

0 commit comments

Comments
 (0)