Skip to content

Commit a8fe0fa

Browse files
committed
refactor: drop get_task_fn parameter from shared safe_get_task
1 parent 024cc9c commit a8fe0fa

7 files changed

Lines changed: 41 additions & 47 deletions

File tree

taskbadger/_integrations.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
(Celery, Procrastinate). Not part of the public API.
33
44
Each integration creates its own module-level ``TaskCache`` instance and
5-
defines a thin ``safe_get_task`` wrapper that reads ``get_task`` from the
6-
integration module's own globals (so existing test mocks on
7-
``taskbadger.celery.get_task`` / ``taskbadger.procrastinate.get_task`` keep
8-
working). ``BaseSystemIntegration`` provides the common ctor/include-exclude
9-
shape; subclasses override ``track_task`` if they need to filter additional
5+
defines a thin ``safe_get_task`` wrapper around the shared one defined here.
6+
``BaseSystemIntegration`` provides the common ctor/include-exclude shape;
7+
subclasses override ``track_task`` if they need to filter additional
108
task names (e.g. Procrastinate built-ins).
119
"""
1210

@@ -15,8 +13,8 @@
1513
import collections
1614
import logging
1715
import re
18-
from collections.abc import Callable
1916

17+
from . import sdk
2018
from .internal.models import StatusEnum
2119
from .systems import System
2220

@@ -53,21 +51,16 @@ def unset(self, key) -> None:
5351
self.cache.pop(key, None)
5452

5553

56-
def safe_get_task(cache: TaskCache, task_id: str, get_task_fn: Callable):
54+
def safe_get_task(cache: TaskCache, task_id: str):
5755
"""Cache-aware ``get_task``: returns the cached entry if present, otherwise
58-
fetches via ``get_task_fn`` and caches the result. Errors are logged and
56+
fetches via the SDK ``get_task`` and caches the result. Errors are logged and
5957
swallowed (returns ``None``). ``None`` results are not cached.
60-
61-
``get_task_fn`` is passed in (rather than imported here) so callers can
62-
use their own module-level ``get_task`` reference — this keeps existing
63-
test patches on ``taskbadger.celery.get_task`` / ``taskbadger.procrastinate.get_task``
64-
intercepting the fetch.
6558
"""
6659
cached = cache.get(task_id)
6760
if cached is not None:
6861
return cached
6962
try:
70-
task = get_task_fn(task_id)
63+
task = sdk.get_task(task_id)
7164
except Exception as e:
7265
log.warning("Error fetching task '%s': %s", task_id, e)
7366
return None

taskbadger/celery.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
)
1313
from kombu import serialization
1414

15+
from . import sdk
1516
from ._integrations import TERMINAL_STATES, TaskCache
1617
from ._integrations import safe_get_task as _shared_safe_get_task
1718
from .internal.models import StatusEnum
1819
from .mug import Badger
1920
from .safe_sdk import create_task_safe, update_task_safe
20-
from .sdk import DefaultMergeStrategy, get_task
21+
from .sdk import DefaultMergeStrategy
2122

2223
KWARG_PREFIX = "taskbadger_"
2324
TB_KWARGS_ARG = f"{KWARG_PREFIX}kwargs"
@@ -91,7 +92,7 @@ def apply_async(self, *args, **kwargs):
9192
tb_task_id = info.get(TB_TASK_ID) if isinstance(info, dict) else None
9293
setattr(result, TB_TASK_ID, tb_task_id)
9394

94-
_get_task = functools.partial(get_task, tb_task_id) if tb_task_id else lambda: None
95+
_get_task = functools.partial(sdk.get_task, tb_task_id) if tb_task_id else lambda: None
9596
setattr(result, "get_taskbadger_task", _get_task)
9697

9798
return result
@@ -328,7 +329,7 @@ def exit_session(signal_sender):
328329

329330

330331
def safe_get_task(task_id: str):
331-
return _shared_safe_get_task(_task_cache, task_id, get_task)
332+
return _shared_safe_get_task(_task_cache, task_id)
332333

333334

334335
def _get_taskbadger_task_id(request):

taskbadger/procrastinate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .internal.models import StatusEnum
2323
from .mug import Badger
2424
from .safe_sdk import create_task_safe, update_task_safe
25-
from .sdk import DefaultMergeStrategy, get_task
25+
from .sdk import DefaultMergeStrategy
2626

2727
log = logging.getLogger("taskbadger")
2828

@@ -141,7 +141,7 @@ def _update_status(tb_id, status, exception=None):
141141

142142

143143
def _safe_get_task(task_id):
144-
return _shared_safe_get_task(_task_cache, task_id, get_task)
144+
return _shared_safe_get_task(_task_cache, task_id)
145145

146146

147147
def _wrap_defer(task):

tests/test_celery.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def add_normal(self, a, b):
4343
with (
4444
mock.patch("taskbadger.celery.create_task_safe") as create,
4545
mock.patch("taskbadger.celery.update_task_safe") as update,
46-
mock.patch("taskbadger.celery.get_task") as get_task,
46+
mock.patch("taskbadger.sdk.get_task") as get_task,
4747
):
4848
tb_task = task_for_test()
4949
create.return_value = tb_task
@@ -71,7 +71,7 @@ def add_with_task_args(self, a, b):
7171
with (
7272
mock.patch("taskbadger.celery.create_task_safe") as create,
7373
mock.patch("taskbadger.celery.update_task_safe"),
74-
mock.patch("taskbadger.celery.get_task"),
74+
mock.patch("taskbadger.sdk.get_task"),
7575
):
7676
create.return_value = task_for_test()
7777

@@ -100,7 +100,7 @@ def add_with_task_args(self, a, b):
100100
with (
101101
mock.patch("taskbadger.celery.create_task_safe") as create,
102102
mock.patch("taskbadger.celery.update_task_safe"),
103-
mock.patch("taskbadger.celery.get_task"),
103+
mock.patch("taskbadger.sdk.get_task"),
104104
):
105105
create.return_value = task_for_test()
106106

@@ -129,7 +129,7 @@ def add_with_task_args(self, a, b):
129129
with (
130130
mock.patch("taskbadger.celery.create_task_safe") as create,
131131
mock.patch("taskbadger.celery.update_task_safe"),
132-
mock.patch("taskbadger.celery.get_task"),
132+
mock.patch("taskbadger.sdk.get_task"),
133133
):
134134
create.return_value = task_for_test()
135135

@@ -162,7 +162,7 @@ def add_with_task_kwargs(self, a, b, c=0):
162162
with (
163163
mock.patch("taskbadger.celery.create_task_safe") as create,
164164
mock.patch("taskbadger.celery.update_task_safe"),
165-
mock.patch("taskbadger.celery.get_task"),
165+
mock.patch("taskbadger.sdk.get_task"),
166166
):
167167
create.return_value = task_for_test()
168168

@@ -206,7 +206,7 @@ def add_task_custom_serialization(self, a):
206206
with (
207207
mock.patch("taskbadger.celery.create_task_safe") as create,
208208
mock.patch("taskbadger.celery.update_task_safe"),
209-
mock.patch("taskbadger.celery.get_task"),
209+
mock.patch("taskbadger.sdk.get_task"),
210210
):
211211
create.return_value = task_for_test()
212212

@@ -240,7 +240,7 @@ def add_with_task_args_in_decorator(self, a, b):
240240
with (
241241
mock.patch("taskbadger.celery.create_task_safe") as create,
242242
mock.patch("taskbadger.celery.update_task_safe"),
243-
mock.patch("taskbadger.celery.get_task"),
243+
mock.patch("taskbadger.sdk.get_task"),
244244
):
245245
create.return_value = task_for_test()
246246

@@ -272,7 +272,7 @@ def add_retry(self, a, b, is_retry=False):
272272
with (
273273
mock.patch("taskbadger.celery.create_task_safe") as create,
274274
mock.patch("taskbadger.celery.update_task_safe") as update,
275-
mock.patch("taskbadger.celery.get_task") as get_task,
275+
mock.patch("taskbadger.sdk.get_task") as get_task,
276276
):
277277
create.return_value = task_for_test()
278278
get_task.return_value = task_for_test()
@@ -292,7 +292,7 @@ def add_retry(self, a, b, is_retry=False):
292292
def test_celery_task_badger_not_configured(celery_session_app, celery_session_worker):
293293
@celery_session_app.task(bind=True, base=Task)
294294
def add_no_tb(self, a, b):
295-
with mock.patch("taskbadger.celery.get_task") as get_task:
295+
with mock.patch("taskbadger.sdk.get_task") as get_task:
296296
assert self.taskbadger_task_id is None
297297
assert Badger.current.session().client is None
298298
get_task.assert_not_called()
@@ -323,7 +323,7 @@ def add_no_tb(self, a, b):
323323
def test_task_direct_call(celery_session_app, celery_session_worker):
324324
@celery_session_app.task(bind=True, base=Task)
325325
def add_direct(self, a, b):
326-
with mock.patch("taskbadger.celery.get_task") as get_task:
326+
with mock.patch("taskbadger.sdk.get_task") as get_task:
327327
assert self.taskbadger_task_id is None
328328
assert Badger.current.session().client is None
329329
get_task.assert_not_called()
@@ -351,7 +351,7 @@ def add_shared_task(self, a, b):
351351
with (
352352
mock.patch("taskbadger.celery.create_task_safe") as create,
353353
mock.patch("taskbadger.celery.update_task_safe") as update,
354-
mock.patch("taskbadger.celery.get_task"),
354+
mock.patch("taskbadger.sdk.get_task"),
355355
):
356356
create.return_value = task_for_test()
357357

@@ -378,7 +378,7 @@ def task_signature(self, a):
378378
with (
379379
mock.patch("taskbadger.celery.create_task_safe") as create,
380380
mock.patch("taskbadger.celery.update_task_safe") as update,
381-
mock.patch("taskbadger.celery.get_task") as get_task,
381+
mock.patch("taskbadger.sdk.get_task") as get_task,
382382
):
383383
create.return_value = task_for_test()
384384

@@ -410,7 +410,7 @@ def task_map_fn(self, a):
410410
with (
411411
mock.patch("taskbadger.celery.create_task_safe") as create,
412412
mock.patch("taskbadger.celery.update_task_safe") as update,
413-
mock.patch("taskbadger.celery.get_task"),
413+
mock.patch("taskbadger.sdk.get_task"),
414414
):
415415
tb_task = task_for_test()
416416
create.return_value = tb_task
@@ -443,7 +443,7 @@ def task_starmap_fn(self, a, b):
443443
with (
444444
mock.patch("taskbadger.celery.create_task_safe") as create,
445445
mock.patch("taskbadger.celery.update_task_safe") as update,
446-
mock.patch("taskbadger.celery.get_task"),
446+
mock.patch("taskbadger.sdk.get_task"),
447447
):
448448
tb_task = task_for_test()
449449
create.return_value = tb_task
@@ -476,7 +476,7 @@ def task_chunks_fn(self, a):
476476
with (
477477
mock.patch("taskbadger.celery.create_task_safe") as create,
478478
mock.patch("taskbadger.celery.update_task_safe") as update,
479-
mock.patch("taskbadger.celery.get_task"),
479+
mock.patch("taskbadger.sdk.get_task"),
480480
):
481481
tb_task = task_for_test()
482482
create.return_value = tb_task
@@ -506,7 +506,7 @@ def add_manual_update(self, a, b, is_retry=False):
506506
with (
507507
mock.patch("taskbadger.celery.create_task_safe") as create,
508508
mock.patch("taskbadger.celery.update_task_safe") as update,
509-
mock.patch("taskbadger.celery.get_task") as get_task,
509+
mock.patch("taskbadger.sdk.get_task") as get_task,
510510
):
511511
create.return_value = task_for_test()
512512

tests/test_celery_error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def add_error(self, a, b):
2121
with (
2222
mock.patch("taskbadger.celery.create_task_safe") as create,
2323
mock.patch("taskbadger.celery.update_task_safe") as update,
24-
mock.patch("taskbadger.celery.get_task") as get_task,
24+
mock.patch("taskbadger.sdk.get_task") as get_task,
2525
):
2626
task = task_for_test()
2727
create.return_value = task

tests/test_celery_system_integration.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def add_normal(self, a, b):
7777
with (
7878
mock.patch("taskbadger.celery.create_task_safe") as create,
7979
mock.patch("taskbadger.celery.update_task_safe") as update,
80-
mock.patch("taskbadger.celery.get_task") as get_task,
80+
mock.patch("taskbadger.sdk.get_task") as get_task,
8181
):
8282
tb_task = task_for_test()
8383
create.return_value = tb_task
@@ -109,7 +109,7 @@ def add_normal(self, a, b):
109109
with (
110110
mock.patch("taskbadger.celery.create_task_safe") as create,
111111
mock.patch("taskbadger.celery.update_task_safe") as update,
112-
mock.patch("taskbadger.celery.get_task") as get_task,
112+
mock.patch("taskbadger.sdk.get_task") as get_task,
113113
):
114114
tb_task = task_for_test()
115115
create.return_value = tb_task
@@ -144,7 +144,7 @@ def add_normal_with_override(a, b):
144144
with (
145145
mock.patch("taskbadger.celery.create_task_safe") as create,
146146
mock.patch("taskbadger.celery.update_task_safe"),
147-
mock.patch("taskbadger.celery.get_task"),
147+
mock.patch("taskbadger.sdk.get_task"),
148148
):
149149
tb_task = task_for_test()
150150
create.return_value = tb_task
@@ -169,7 +169,7 @@ def add_with_tags(a, b):
169169
with (
170170
mock.patch("taskbadger.sdk.task_create.sync_detailed") as create,
171171
mock.patch("taskbadger.celery.update_task_safe"),
172-
mock.patch("taskbadger.celery.get_task"),
172+
mock.patch("taskbadger.sdk.get_task"),
173173
):
174174
tb_task = task_for_test()
175175
create.return_value = Response(

tests/test_procrastinate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def add(a, b):
4444

4545
with (
4646
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
47-
mock.patch("taskbadger.procrastinate.get_task") as get,
47+
mock.patch("taskbadger.sdk.get_task") as get,
4848
):
4949
get.return_value = task_for_test(status=StatusEnum.PROCESSING)
5050
add.func(a=2, b=3, **{TB_TASK_ID_KWARG: "tb-123"})
@@ -65,7 +65,7 @@ async def add_async(a, b):
6565

6666
with (
6767
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
68-
mock.patch("taskbadger.procrastinate.get_task") as get,
68+
mock.patch("taskbadger.sdk.get_task") as get,
6969
):
7070
get.return_value = task_for_test(status=StatusEnum.PROCESSING)
7171
result = asyncio.run(add_async.func(a=2, b=3, **{TB_TASK_ID_KWARG: "tb-456"}))
@@ -85,7 +85,7 @@ def boom():
8585

8686
with (
8787
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
88-
mock.patch("taskbadger.procrastinate.get_task") as get,
88+
mock.patch("taskbadger.sdk.get_task") as get,
8989
):
9090
get.return_value = task_for_test(status=StatusEnum.PROCESSING, data={"x": 1})
9191
update.return_value = task_for_test(status=StatusEnum.PROCESSING, data={"x": 1})
@@ -179,7 +179,7 @@ def add6(a, b):
179179
with (
180180
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb) as create,
181181
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
182-
mock.patch("taskbadger.procrastinate.get_task") as get,
182+
mock.patch("taskbadger.sdk.get_task") as get,
183183
):
184184
get.return_value = task_for_test(id=tb.id, status=StatusEnum.PROCESSING)
185185
add6.defer(a=2, b=3)
@@ -270,7 +270,7 @@ def capture():
270270
with (
271271
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb),
272272
mock.patch("taskbadger.procrastinate.update_task_safe", return_value=tb),
273-
mock.patch("taskbadger.procrastinate.get_task", return_value=tb),
273+
mock.patch("taskbadger.sdk.get_task", return_value=tb),
274274
):
275275
capture.defer()
276276
app.run_worker(wait=False, install_signal_handlers=False, listen_notify=False)
@@ -296,7 +296,7 @@ def self_complete():
296296
with (
297297
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb_pending),
298298
mock.patch("taskbadger.procrastinate.update_task_safe") as update,
299-
mock.patch("taskbadger.procrastinate.get_task", return_value=tb_done),
299+
mock.patch("taskbadger.sdk.get_task", return_value=tb_done),
300300
):
301301
self_complete.defer()
302302
app.run_worker(wait=False, install_signal_handlers=False, listen_notify=False)
@@ -340,7 +340,7 @@ def ctx_task(context, a):
340340
with (
341341
mock.patch("taskbadger.procrastinate.create_task_safe", return_value=tb),
342342
mock.patch("taskbadger.procrastinate.update_task_safe"),
343-
mock.patch("taskbadger.procrastinate.get_task", return_value=tb),
343+
mock.patch("taskbadger.sdk.get_task", return_value=tb),
344344
):
345345
ctx_task.defer(a=42)
346346
app.run_worker(wait=False, install_signal_handlers=False, listen_notify=False)

0 commit comments

Comments
 (0)