Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ dev = [
"freezegun>=1.5.5",
"tzdata>=2025.2; sys_platform == 'win32'",
"opentelemetry-test-utils (>=0.59b0,<1)",
"polyfactory>=3.1.0",
]

[project.urls]
Expand Down Expand Up @@ -172,11 +173,14 @@ lint.ignore = [
"PLR0913", # Too many arguments for function call
"D106", # Missing docstring in public nested class
]
exclude = [".venv/"]
lint.mccabe = { max-complexity = 10 }
exclude = [".venv/"]
line-length = 88

[tool.ruff.lint.per-file-ignores]
"taskiq/compat.py" = [
"D103", # Missing docstring in public function
]
"tests/*" = [
"S101", # Use of assert detected
"S301", # Use of pickle detected
Expand Down
36 changes: 18 additions & 18 deletions taskiq/compat.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
# flake8: noqa
from collections.abc import Hashable
from functools import lru_cache
from importlib.metadata import version
from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union
from typing import Any, TypeVar

import pydantic
from packaging.version import Version, parse

PYDANTIC_VER = parse(version("pydantic"))

Model = TypeVar("Model", bound="pydantic.BaseModel")
IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0")
IS_PYDANTIC2 = Version("2.0") <= PYDANTIC_VER

if IS_PYDANTIC2:
T = TypeVar("T", bound=Hashable)

@lru_cache()
def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]:
@lru_cache
def create_type_adapter(annot: type[T]) -> pydantic.TypeAdapter[T]:
return pydantic.TypeAdapter(annot)

def parse_obj_as(annot: Type[T], obj: Any) -> T:
def parse_obj_as(annot: type[T], obj: Any) -> T:
return create_type_adapter(annot).validate_python(obj)

def model_validate(
model_class: Type[Model],
message: Dict[str, Any],
model_class: type[Model],
message: dict[str, Any],
) -> Model:
return model_class.model_validate(message)

def model_dump(instance: Model) -> Dict[str, Any]:
def model_dump(instance: Model) -> dict[str, Any]:
return instance.model_dump(mode="json")

def model_validate_json(
model_class: Type[Model],
message: Union[str, bytes, bytearray],
model_class: type[Model],
message: str | bytes | bytearray,
) -> Model:
return model_class.model_validate_json(message)

Expand All @@ -41,7 +41,7 @@ def model_dump_json(instance: Model) -> str:

def model_copy(
instance: Model,
update: Optional[Dict[str, Any]] = None,
update: dict[str, Any] | None = None,
deep: bool = False,
) -> Model:
return instance.model_copy(update=update, deep=deep)
Expand All @@ -52,17 +52,17 @@ def model_copy(
parse_obj_as = pydantic.parse_obj_as # type: ignore

def model_validate(
model_class: Type[Model],
message: Dict[str, Any],
model_class: type[Model],
message: dict[str, Any],
) -> Model:
return model_class.parse_obj(message)

def model_dump(instance: Model) -> Dict[str, Any]:
def model_dump(instance: Model) -> dict[str, Any]:
return instance.dict()

def model_validate_json(
model_class: Type[Model],
message: Union[str, bytes, bytearray],
model_class: type[Model],
message: str | bytes | bytearray,
) -> Model:
return model_class.parse_raw(message) # type: ignore[arg-type]

Expand All @@ -71,7 +71,7 @@ def model_dump_json(instance: Model) -> str:

def model_copy(
instance: Model,
update: Optional[Dict[str, Any]] = None,
update: dict[str, Any] | None = None,
deep: bool = False,
) -> Model:
return instance.copy(update=update, deep=deep)
Expand Down
18 changes: 11 additions & 7 deletions taskiq/middlewares/taskiq_admin_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import aiohttp

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.compat import model_dump
from taskiq.message import TaskiqMessage
from taskiq.result import TaskiqResult

Expand Down Expand Up @@ -115,12 +116,13 @@ async def post_send(self, message: TaskiqMessage) -> None:

:param message: kicked message.
"""
dict_message: dict[str, Any] = model_dump(message)
await self._spawn_request(
f"/api/tasks/{message.task_id}/queued",
{
"args": message.args,
"kwargs": message.kwargs,
"labels": message.labels,
"args": dict_message["args"],
"kwargs": dict_message["kwargs"],
"labels": dict_message["labels"],
"queuedAt": self._now_iso(),
"taskName": message.task_name,
"worker": self.__ta_broker_name,
Expand All @@ -137,12 +139,13 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
:param message: incoming parsed taskiq message.
:return: modified message.
"""
dict_message: dict[str, Any] = model_dump(message)
await self._spawn_request(
f"/api/tasks/{message.task_id}/started",
{
"args": message.args,
"kwargs": message.kwargs,
"labels": message.labels,
"args": dict_message["args"],
"kwargs": dict_message["kwargs"],
"labels": dict_message["labels"],
"startedAt": self._now_iso(),
"taskName": message.task_name,
"worker": self.__ta_broker_name,
Expand All @@ -164,12 +167,13 @@ async def post_execute(
:param message: incoming message.
:param result: result of execution for current task.
"""
dict_result: dict[str, Any] = model_dump(result)
await self._spawn_request(
f"/api/tasks/{message.task_id}/executed",
{
"finishedAt": self._now_iso(),
"executionTime": result.execution_time,
"error": None if result.error is None else repr(result.error),
"returnValue": {"return_value": result.return_value},
"returnValue": {"return_value": dict_result["return_value"]},
},
)
69 changes: 69 additions & 0 deletions tests/middlewares/admin_middleware/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
from aiohttp import web
from aiohttp.test_utils import TestServer
from typing_extensions import AsyncGenerator

from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.brokers.shared_broker import async_shared_broker
from taskiq.middlewares import TaskiqAdminMiddleware
from tests.middlewares.admin_middleware.dto import (
DataclassDTO,
PydanticDTO,
TypedDictDTO,
)


@pytest.fixture
async def admin_api_server() -> AsyncGenerator[TestServer, None]:
async def handle_queued(request: web.Request) -> web.Response:
return web.json_response({"status": "ok"}, status=200)

async def handle_started(request: web.Request) -> web.Response:
return web.json_response({"status": "ok"}, status=200)

async def handle_executed(request: web.Request) -> web.Response:
return web.json_response({"status": "ok"}, status=200)

app = web.Application()
app.router.add_post("/api/tasks/{task_id}/queued", handle_queued)
app.router.add_post("/api/tasks/{task_id}/started", handle_started)
app.router.add_post("/api/tasks/{task_id}/executed", handle_executed)

server = TestServer(app)
await server.start_server()
yield server
await server.close()


@pytest.fixture
async def broker_with_admin_middleware(
admin_api_server: TestServer,
) -> AsyncGenerator[InMemoryBroker, None]:
broker = InMemoryBroker(await_inplace=True).with_middlewares(
TaskiqAdminMiddleware(
str(admin_api_server.make_url("/")), # URL тестового сервера
"supersecret",
taskiq_broker_name="InMemory",
),
)

broker.register_task(task_with_dataclass, task_name="task_with_dataclass")
broker.register_task(task_with_typed_dict, task_name="task_with_typed_dict")
broker.register_task(task_with_pydantic_model, task_name="task_with_pydantic_model")
async_shared_broker.default_broker(broker)

await broker.startup()
yield broker
await broker.shutdown()


async def task_with_dataclass(dto: DataclassDTO) -> None:
assert dto


async def task_with_typed_dict(dto: TypedDictDTO) -> None:
assert dto


async def task_with_pydantic_model(dto: PydanticDTO) -> None:
assert dto
36 changes: 36 additions & 0 deletions tests/middlewares/admin_middleware/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from dataclasses import dataclass

import pydantic
from typing_extensions import TypedDict


@dataclass(frozen=True, slots=True)
class DataclassNestedDTO:
id: int
name: str


@dataclass(frozen=True, slots=True)
class DataclassDTO:
nested: DataclassNestedDTO
recipients: list[str]
subject: str
attachments: list[str] | None = None
text: str | None = None
html: str | None = None


class PydanticDTO(pydantic.BaseModel):
number: int
text: str
flag: bool
list: list[float]
dictionary: dict[str, str] | None = None


class TypedDictDTO(TypedDict):
id: int
name: str
active: bool
scores: list[int]
metadata: dict[str, str] | None
53 changes: 53 additions & 0 deletions tests/middlewares/admin_middleware/test_arguments_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from polyfactory.factories import BaseFactory, DataclassFactory, TypedDictFactory
from polyfactory.factories.pydantic_factory import ModelFactory

from taskiq.brokers.inmemory_broker import InMemoryBroker
from tests.middlewares.admin_middleware.dto import (
DataclassDTO,
PydanticDTO,
TypedDictDTO,
)


class DataclassDTOFactory(DataclassFactory[DataclassDTO]):
__model__ = DataclassDTO


class TypedDictDTOFactory(TypedDictFactory[TypedDictDTO]):
__model__ = TypedDictDTO


class PydanticDTOFactory(ModelFactory[PydanticDTO]):
__model__ = PydanticDTO


# @pytest.mark.skip
class TestArgumentsFormattingInAdminMiddleware:
@pytest.mark.parametrize(
"dto_factory, task_name",
[
pytest.param(DataclassDTOFactory, "task_with_dataclass", id="dataclass"),
pytest.param(TypedDictDTOFactory, "task_with_typed_dict", id="typeddict"),
pytest.param(PydanticDTOFactory, "task_with_pydantic_model", id="pydantic"),
],
)
async def test_when_task_dto_passed__then_middleware_successfully_send_request(
self,
broker_with_admin_middleware: InMemoryBroker,
dto_factory: type[BaseFactory], # type: ignore[type-arg]
task_name: str,
) -> None:
# given
task_arguments = dto_factory.build()
task = broker_with_admin_middleware.find_task(task_name)
assert task is not None, f"Task {task_name} should be registered in the broker"

# when
kicked_task = await task.kiq(task_arguments)
await broker_with_admin_middleware.wait_all()

# then
result = await kicked_task.get_result()
# we just expect no errors during post_send/pre_execute/post_execute
assert result.error is None
Loading