Skip to content

Commit 2c52537

Browse files
committed
refactor: share task cache instance between celery and procrastinate
1 parent a8fe0fa commit 2c52537

6 files changed

Lines changed: 33 additions & 49 deletions

File tree

taskbadger/_integrations.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Shared internals for taskbadger's optional system integrations
22
(Celery, Procrastinate). Not part of the public API.
33
4-
Each integration creates its own module-level ``TaskCache`` instance and
5-
defines a thin ``safe_get_task`` wrapper around the shared one defined here.
6-
``BaseSystemIntegration`` provides the common ctor/include-exclude shape;
7-
subclasses override ``track_task`` if they need to filter additional
4+
A single module-level ``TaskCache`` (``task_cache``) is shared across all
5+
integrations; task ids are UUIDs so cross-integration key collisions are not
6+
a concern. ``BaseSystemIntegration`` provides the common ctor/include-exclude
7+
shape; subclasses override ``track_task`` if they need to filter additional
88
task names (e.g. Procrastinate built-ins).
99
"""
1010

@@ -51,20 +51,23 @@ def unset(self, key) -> None:
5151
self.cache.pop(key, None)
5252

5353

54-
def safe_get_task(cache: TaskCache, task_id: str):
54+
task_cache = TaskCache()
55+
56+
57+
def safe_get_task(task_id: str):
5558
"""Cache-aware ``get_task``: returns the cached entry if present, otherwise
56-
fetches via the SDK ``get_task`` and caches the result. Errors are logged and
57-
swallowed (returns ``None``). ``None`` results are not cached.
59+
fetches from the API and caches the result. Errors are logged and swallowed
60+
(returns ``None``). ``None`` results are not cached.
5861
"""
59-
cached = cache.get(task_id)
62+
cached = task_cache.get(task_id)
6063
if cached is not None:
6164
return cached
6265
try:
6366
task = sdk.get_task(task_id)
6467
except Exception as e:
6568
log.warning("Error fetching task '%s': %s", task_id, e)
6669
return None
67-
cache.set(task_id, task)
70+
task_cache.set(task_id, task)
6871
return task
6972

7073

taskbadger/celery.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
from kombu import serialization
1414

1515
from . import sdk
16-
from ._integrations import TERMINAL_STATES, TaskCache
17-
from ._integrations import safe_get_task as _shared_safe_get_task
16+
from ._integrations import TERMINAL_STATES, safe_get_task, task_cache
1817
from .internal.models import StatusEnum
1918
from .mug import Badger
2019
from .safe_sdk import create_task_safe, update_task_safe
@@ -27,8 +26,6 @@
2726

2827
log = logging.getLogger("taskbadger")
2928

30-
_task_cache = TaskCache()
31-
3229

3330
class Task(celery.Task):
3431
"""A Celery Task that tracks itself with TaskBadger.
@@ -249,7 +246,7 @@ def _maybe_create_task(signal_sender):
249246
if task:
250247
# Store the task ID in the request so _update_task can find it
251248
signal_sender.request.update({TB_TASK_ID: task.id})
252-
_task_cache.set(task.id, task)
249+
task_cache.set(task.id, task)
253250

254251

255252
@task_prerun.connect
@@ -301,7 +298,7 @@ def _update_task(signal_sender, status, einfo=None):
301298
data = DefaultMergeStrategy().merge(task.data, {"exception": str(einfo)})
302299
task = update_task_safe(task.id, status=status, data=data)
303300
if task:
304-
_task_cache.set(task_id, task)
301+
task_cache.set(task_id, task)
305302

306303

307304
def enter_session():
@@ -321,17 +318,13 @@ def exit_session(signal_sender):
321318
if not task_id or not Badger.is_configured():
322319
return
323320

324-
_task_cache.unset(task_id)
321+
task_cache.unset(task_id)
325322

326323
session = Badger.current.session()
327324
if session.client:
328325
session.__exit__()
329326

330327

331-
def safe_get_task(task_id: str):
332-
return _shared_safe_get_task(_task_cache, task_id)
333-
334-
335328
def _get_taskbadger_task_id(request):
336329
if not request:
337330
return

taskbadger/procrastinate.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import logging
1818
from contextvars import ContextVar
1919

20-
from ._integrations import TERMINAL_STATES, TaskCache
21-
from ._integrations import safe_get_task as _shared_safe_get_task
20+
from ._integrations import TERMINAL_STATES, safe_get_task, task_cache
2221
from .internal.models import StatusEnum
2322
from .mug import Badger
2423
from .safe_sdk import create_task_safe, update_task_safe
@@ -118,8 +117,8 @@ def _update_status(tb_id, status, exception=None):
118117
# Bypass the cache for the terminal-state check: the user may have
119118
# updated the task to a terminal state via the regular SDK during
120119
# the body, which wouldn't be reflected in our local cache.
121-
_task_cache.unset(tb_id)
122-
current = _safe_get_task(tb_id)
120+
task_cache.unset(tb_id)
121+
current = safe_get_task(tb_id)
123122
if current is not None and current.status in TERMINAL_STATES:
124123
return
125124
data = None
@@ -134,14 +133,7 @@ def _update_status(tb_id, status, exception=None):
134133
updated = update_task_safe(tb_id, status=status)
135134

136135
if updated is not None:
137-
_task_cache.set(tb_id, updated)
138-
139-
140-
_task_cache = TaskCache()
141-
142-
143-
def _safe_get_task(task_id):
144-
return _shared_safe_get_task(_task_cache, task_id)
136+
task_cache.set(tb_id, updated)
145137

146138

147139
def _wrap_defer(task):
@@ -267,7 +259,7 @@ def current_task():
267259
tb_id = _current_tb_task_id.get()
268260
if tb_id is None:
269261
return None
270-
return _safe_get_task(tb_id)
262+
return safe_get_task(tb_id)
271263

272264

273265
def _patch_app_task(app, system):

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
import pytest
22

3+
from taskbadger._integrations import task_cache
34
from taskbadger.mug import Badger, Settings
45

56

7+
@pytest.fixture(autouse=True)
8+
def _clear_task_cache():
9+
"""Clear the shared integrations task cache around every test so cached
10+
entries from earlier tests can't leak into later ones."""
11+
task_cache.cache.clear()
12+
yield
13+
task_cache.cache.clear()
14+
15+
616
@pytest.fixture
717
def _bind_settings():
818
Badger.current.bind(Settings("https://taskbadger.net", "token", "org", "proj"))

tests/test_procrastinate.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from procrastinate import testing
88

99
from taskbadger import StatusEnum
10-
from taskbadger.procrastinate import TB_TASK_ID_KWARG, _instrument_task, _task_cache, current_task, track
10+
from taskbadger.procrastinate import TB_TASK_ID_KWARG, _instrument_task, current_task, track
1111
from tests.utils import task_for_test
1212

1313

@@ -19,13 +19,6 @@ def _check_log_errors(caplog):
1919
pytest.fail(f"log errors during tests: {errors}")
2020

2121

22-
@pytest.fixture(autouse=True)
23-
def _clear_task_cache():
24-
_task_cache.cache.clear()
25-
yield
26-
_task_cache.cache.clear()
27-
28-
2922
@pytest.fixture
3023
def app():
3124
in_memory = testing.InMemoryConnector()

tests/test_procrastinate_system_integration.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,11 @@
44
import pytest
55
from procrastinate import testing
66

7-
from taskbadger.procrastinate import _INSTRUMENTED_ATTR, TB_TASK_ID_KWARG, _task_cache, track
7+
from taskbadger.procrastinate import _INSTRUMENTED_ATTR, TB_TASK_ID_KWARG, track
88
from taskbadger.systems.procrastinate import ProcrastinateSystemIntegration
99
from tests.utils import task_for_test
1010

1111

12-
@pytest.fixture(autouse=True)
13-
def _clear_task_cache():
14-
_task_cache.cache.clear()
15-
yield
16-
_task_cache.cache.clear()
17-
18-
1912
@pytest.fixture
2013
def app():
2114
in_memory = testing.InMemoryConnector()

0 commit comments

Comments
 (0)