Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions taskbadger/procrastinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,22 @@ async def defer_async(**kwargs):
task.defer_async = defer_async


def _maybe_create_pending(task, kwargs):
"""Decide whether to track this defer, and if so create the TaskBadger
task and inject its id into ``kwargs``. Always returns the kwargs dict."""
def _create_pending_task(task, task_kwargs):
"""Create a PENDING TaskBadger task for ``task`` if it should be tracked.

Returns the created TaskBadger task, or ``None`` if Badger isn't
configured, the task isn't tracked (neither manual nor auto), or the
create call failed. ``task_kwargs`` is used only for the
``record_task_args`` data capture.
"""
if not Badger.is_configured():
return kwargs
return None

system = getattr(task, "_taskbadger_system", None)
manual = getattr(task, _MANUAL_ATTR, False)
auto = bool(system) and system.track_task(task.name)
if not manual and not auto:
return kwargs
return None

opts = dict(getattr(task, _OPTS_ATTR, {}) or {})
name = opts.pop("name", None) or task.name
Expand All @@ -185,12 +190,18 @@ def _maybe_create_pending(task, kwargs):
if record_args is None:
record_args = bool(system) and system.record_task_args
if record_args:
data["procrastinate_task_kwargs"] = _serialize_kwargs(kwargs)
data["procrastinate_task_kwargs"] = _serialize_kwargs(task_kwargs)

if data:
create_kwargs["data"] = data

tb_task = create_task_safe(name, **create_kwargs)
return create_task_safe(name, **create_kwargs)


def _maybe_create_pending(task, kwargs):
"""Decide whether to track this defer, and if so create the TaskBadger
task and inject its id into ``kwargs``. Always returns the kwargs dict."""
tb_task = _create_pending_task(task, kwargs)
if tb_task is None:
return kwargs

Expand Down Expand Up @@ -262,6 +273,37 @@ def current_task():
return safe_get_task(tb_id)


def _patch_job_manager(app, system):
"""Patch ``app.job_manager.defer_periodic_job`` so periodic tasks are tracked.

Procrastinate's ``PeriodicDeferrer`` enqueues jobs by calling
``job_manager.defer_periodic_job(job=..., ...)`` directly, bypassing
``task.defer``/``defer_async``. Without this hook, ``@app.periodic`` tasks
would never get a PENDING TaskBadger task created at enqueue time.

Idempotent: a second call updates the system reference but doesn't
re-wrap.
"""
jm = app.job_manager
if not getattr(jm, "_taskbadger_original_defer_periodic_job", None):
original = jm.defer_periodic_job
jm._taskbadger_original_defer_periodic_job = original

@functools.wraps(original)
async def patched(*, job, periodic_id, defer_timestamp):
task = app.tasks.get(job.task_name)
if task is not None:
tb_task = _create_pending_task(task, job.task_kwargs)
if tb_task is not None:
new_kwargs = {**job.task_kwargs, TB_TASK_ID_KWARG: tb_task.id}
job = job.evolve(task_kwargs=new_kwargs)
return await jm._taskbadger_original_defer_periodic_job(
job=job, periodic_id=periodic_id, defer_timestamp=defer_timestamp
)

jm.defer_periodic_job = patched


def _patch_app_task(app, system):
"""Replace ``app.task`` with a wrapper that instruments newly-registered
tasks under the supplied ``system``. Idempotent — a second call replaces
Expand Down
3 changes: 2 additions & 1 deletion taskbadger/systems/procrastinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from taskbadger._integrations import BaseSystemIntegration
from taskbadger.procrastinate import _instrument_task, _patch_app_task
from taskbadger.procrastinate import _instrument_task, _patch_app_task, _patch_job_manager


class ProcrastinateSystemIntegration(BaseSystemIntegration):
Expand Down Expand Up @@ -41,6 +41,7 @@ def __init__(
for task in list(app.tasks.values()):
_instrument_task(task, system=self)
_patch_app_task(app, system=self)
_patch_job_manager(app, system=self)

def track_task(self, task_name):
# Never auto-track Procrastinate's built-in housekeeping tasks
Expand Down
45 changes: 45 additions & 0 deletions tests/test_procrastinate_system_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest import mock

import procrastinate
Expand Down Expand Up @@ -98,6 +99,50 @@ def late(a):
create.assert_called_once()


@pytest.mark.usefixtures("_bind_settings")
def test_periodic_defer_creates_pending(app):
"""Periodic tasks are deferred via ``app.job_manager.defer_periodic_job``,
which bypasses ``task.defer``/``defer_async`` entirely. The system
integration must hook this path too, otherwise periodic jobs are invisible
to TaskBadger."""

@app.task(name="periodic_target")
def periodic_target(timestamp):
return timestamp

ProcrastinateSystemIntegration(app=app, auto_track_tasks=True)

timestamp = 1700000000
job = periodic_target.configure(task_kwargs={"timestamp": timestamp}).job

tb = task_for_test()
with mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb) as create:
asyncio.run(app.job_manager.defer_periodic_job(job=job, periodic_id="every-min", defer_timestamp=timestamp))

create.assert_called_once()
jobs_stored = list(app.connector.jobs.values())
assert jobs_stored[0]["args"][TB_TASK_ID_KWARG] == tb.id


@pytest.mark.usefixtures("_bind_settings")
def test_periodic_defer_skips_excluded(app):
"""Excludes apply on the periodic path too."""

@app.task(name="myapp.cleanup.flush")
def flush(timestamp):
pass

ProcrastinateSystemIntegration(app=app, auto_track_tasks=True, excludes=[r"myapp\.cleanup\..*"])

timestamp = 1700000000
job = flush.configure(task_kwargs={"timestamp": timestamp}).job

with mock.patch("taskbadger.procrastinate.create_task_safe") as create:
asyncio.run(app.job_manager.defer_periodic_job(job=job, periodic_id="every-min", defer_timestamp=timestamp))

create.assert_not_called()


@pytest.mark.usefixtures("_bind_settings")
def test_track_plus_auto_track_no_double_wrap(app):
@track
Expand Down
Loading