From 45014022cec3f15df7faad0aab79241bac6391b7 Mon Sep 17 00:00:00 2001 From: Max Kas <93253421+mku11@users.noreply.github.com> Date: Tue, 24 Mar 2026 16:15:30 +0200 Subject: [PATCH] Support for task workflow/dependencies --- requirements-test.txt | 3 +- tasktiger/_internal.py | 2 + tasktiger/lua/move_task.lua | 6 ++- tasktiger/redis_scripts.py | 6 ++- tasktiger/task.py | 63 ++++++++++++++++++++++++++-- tasktiger/tasktiger.py | 83 +++++++++++++++++++++++++++++++++++-- tasktiger/worker.py | 59 +++++++++++++++++++++++++- tests/test_base.py | 8 +--- tests/test_logging.py | 2 +- tests/test_periodic.py | 2 +- tests/test_workers.py | 32 +++++++++++++- tests/utils.py | 16 +++++++ 12 files changed, 262 insertions(+), 20 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index 19ea9018..3c2ed75f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,4 @@ freezefrog==0.4.1 psutil==5.9.8 -pytest==8.1.1 \ No newline at end of file +pytest==8.1.1 +pytz \ No newline at end of file diff --git a/tasktiger/_internal.py b/tasktiger/_internal.py index 01c913e4..4b258dc4 100644 --- a/tasktiger/_internal.py +++ b/tasktiger/_internal.py @@ -33,6 +33,8 @@ ACTIVE = "active" SCHEDULED = "scheduled" ERROR = "error" +WAITING = "waiting" +COMPLETED = "completed" # This lock is acquired in the main process when forking, and must be acquired # in any thread of the main process when performing an operation that triggers a diff --git a/tasktiger/lua/move_task.lua b/tasktiger/lua/move_task.lua index cd5516fd..7d88278d 100644 --- a/tasktiger/lua/move_task.lua +++ b/tasktiger/lua/move_task.lua @@ -20,7 +20,9 @@ local key_active_queue = KEYS[6] local key_queued_queue = KEYS[7] local key_error_queue = KEYS[8] local key_scheduled_queue = KEYS[9] -local key_activity = KEYS[10] +local key_waiting_queue = KEYS[10] +local key_completed_queue = KEYS[11] +local key_activity = KEYS[12] local id = ARGV[1] local queue = ARGV[2] @@ -36,6 +38,8 @@ local state_queues_keys_by_state = { queued = key_queued_queue, error = key_error_queue, scheduled = key_scheduled_queue, + waiting = key_waiting_queue, + completed = key_completed_queue, } local key_from_state_queue = state_queues_keys_by_state[from_state] local key_to_state_queue = state_queues_keys_by_state[to_state] diff --git a/tasktiger/redis_scripts.py b/tasktiger/redis_scripts.py index 5b93dfac..5915206b 100644 --- a/tasktiger/redis_scripts.py +++ b/tasktiger/redis_scripts.py @@ -3,7 +3,7 @@ from redis import Redis -from ._internal import ACTIVE, ERROR, QUEUED, SCHEDULED +from ._internal import ACTIVE, ERROR, QUEUED, SCHEDULED, WAITING, COMPLETED try: from redis.commands.core import Script @@ -634,6 +634,8 @@ def _none_to_empty_str(v: Optional[str]) -> str: key_queued_queue = key_func(QUEUED, queue) key_error_queue = key_func(ERROR, queue) key_scheduled_queue = key_func(SCHEDULED, queue) + key_waiting_queue = key_func(WAITING, queue) + key_completed_queue = key_func(COMPLETED, queue) key_activity = key_func("activity") return self._move_task( @@ -647,6 +649,8 @@ def _none_to_empty_str(v: Optional[str]) -> str: key_queued_queue, key_error_queue, key_scheduled_queue, + key_waiting_queue, + key_completed_queue, key_activity, ], args=[ diff --git a/tasktiger/task.py b/tasktiger/task.py index 2343ba60..cd7fe8f5 100644 --- a/tasktiger/task.py +++ b/tasktiger/task.py @@ -19,9 +19,12 @@ from structlog.stdlib import BoundLogger from ._internal import ( + ACTIVE, ERROR, QUEUED, SCHEDULED, + WAITING, + COMPLETED, g, gen_id, gen_unique_id, @@ -53,6 +56,7 @@ def __init__( unique_key: Optional[Collection[str]] = None, lock: Optional[bool] = None, lock_key: Optional[Collection[str]] = None, + depends: Optional[Union[str, Collection[str]]] = None, retry: Optional[bool] = None, retry_on: Optional[Collection[Type[BaseException]]] = None, retry_method: Optional[ @@ -140,6 +144,8 @@ def __init__( task["unique"] = True if unique_key: task["unique_key"] = unique_key + if depends: + task["depends"] = depends if lock or lock_key: task["lock"] = True if lock_key: @@ -212,6 +218,10 @@ def serialized_func(self) -> str: def lock(self) -> bool: return self._data.get("lock", False) + @property + def depends(self) -> List[str]: + return self._data.get("depends", None) + @property def lock_key(self) -> Optional[str]: return self._data.get("lock_key") @@ -495,6 +505,48 @@ def from_id( else: raise TaskNotFound("Task {} not found.".format(task_id)) + def get_dependencies(self, states: Optional[List[str]] = None) -> List["Task"]: + """ + Get the dependency tasks, use the states param to filter which states to search. + Use only for reporting! + """ + tasks: List[Task] = [] + if not self.depends: + return tasks + if not states: + states = [QUEUED, ACTIVE, SCHEDULED, ERROR, WAITING, COMPLETED] + for dep_task_id in self.depends: + dep_task = None + for state in states: + try: + dep_task = self._get_dependency(state, self.queue, dep_task_id) + if dep_task: + break + except Exception: + pass + if dep_task: + tasks.append(dep_task) + else: + tasks.append(Task(self.tiger, queue="Not Found", _data={"id": dep_task_id})) + return tasks + + def _get_dependency( + self, state: str, queue: str, task_id: str + ) -> Union["Task", None]: + """ + Get the dependency task for the queue if it exists to avoid raising exceptions. + """ + exists = self.tiger.connection.zscore(self.tiger._key(state, queue), task_id) + if exists: + dep_task = Task.from_id( + tiger=self.tiger, + queue=queue, + state=state, + task_id=task_id, + ) + return dep_task + return None + @classmethod def tasks_from_queue( cls, @@ -621,12 +673,15 @@ def cancel(self) -> None: def delete(self) -> None: """ - Removes a task that's in the error queue. + Removes a task that's in the error or completed queue. - Raises TaskNotFound if the task could not be found in the ERROR - queue. + Raises TaskNotFound if the task could not be found + in the COMPLETED or ERROR queue. """ - self._move(from_state=ERROR) + if self.state == COMPLETED: + self._move(from_state=COMPLETED) + else: + self._move(from_state=ERROR) def clone(self) -> "Task": """Returns a clone of the this task""" diff --git a/tasktiger/tasktiger.py b/tasktiger/tasktiger.py index aeb611b9..46b20de3 100644 --- a/tasktiger/tasktiger.py +++ b/tasktiger/tasktiger.py @@ -1,6 +1,10 @@ import datetime import importlib import logging +import sys +import os +import signal +import time from collections import defaultdict from typing import ( Any, @@ -25,6 +29,8 @@ ERROR, QUEUED, SCHEDULED, + WAITING, + COMPLETED, classproperty, g, queue_matches, @@ -395,6 +401,7 @@ def run_worker( module: Optional[str] = None, exclude_queues: Optional[str] = None, max_workers_per_queue: Optional[int] = None, + max_parallel_workers: Optional[int] = None, store_tracebacks: Optional[bool] = None, executor_class: Optional[Type[Executor]] = None, exit_after: Optional[datetime.timedelta] = None, @@ -405,7 +412,67 @@ def run_worker( The arguments are explained in the module-level run_worker() method's click options. """ + + if max_parallel_workers is None or max_parallel_workers <= 0: + max_parallel_workers = 1 + + # when run parallel workers we ignore ctrl-c for the main process + # since the children workers will handle the signal and finish gracefully + if max_parallel_workers > 0: + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + worker_pids: set[int] = set() + for _ in range(max_parallel_workers): + pid = os.getpid() + if max_parallel_workers > 1: + pid = os.fork() + if pid != 0: + worker_pids.add(pid) + continue + + self._start_worker( + queues, + module, + exclude_queues, + max_workers_per_queue, + store_tracebacks, + executor_class, + exit_after, + ) + + if pid == 0: + # for children we clear the inherited values from the parent + worker_pids.clear() + max_parallel_workers = 1 + break + # wait for children to finish + while len(worker_pids) > 0: + time.sleep(1) + try: + pid, _ = os.waitpid(-1, os.WNOHANG) + if pid != 0: + worker_pids.remove(pid) + except ChildProcessError as ex: + if ex.errno == 10: # no more children + worker_pids.clear() + except Exception as ex: + print(ex, type(ex), file=sys.stderr) + + def _start_worker( + self, + queues: Optional[str] = None, + module: Optional[str] = None, + exclude_queues: Optional[str] = None, + max_workers_per_queue: Optional[int] = None, + store_tracebacks: Optional[bool] = None, + executor_class: Optional[Type[Executor]] = None, + exit_after: Optional[datetime.timedelta] = None, + ) -> None: + """ + Start worker + """ try: module_names = module or "" for module_name in module_names.split(","): @@ -439,6 +506,7 @@ def delay( lock: Optional[bool] = None, lock_key: Optional[Collection[str]] = None, when: Optional[Union[datetime.datetime, datetime.timedelta]] = None, + depends: Optional[Union[str, Collection[str]]] = None, retry: Optional[bool] = None, retry_on: Optional[Collection[Type[BaseException]]] = None, retry_method: Optional[ @@ -463,6 +531,7 @@ def delay( unique_key=unique_key, lock=lock, lock_key=lock_key, + depends=depends, retry=retry, retry_on=retry_on, retry_method=retry_method, @@ -479,7 +548,7 @@ def get_queue_sizes(self, queue: str) -> Dict[str, int]: Get the queue's number of tasks in each state. Returns dict with queue size for the QUEUED, SCHEDULED, and ACTIVE - states. Does not include size of error queue. + states. Does not include size of error queue nor completed queue. """ states = [QUEUED, SCHEDULED, ACTIVE] @@ -541,13 +610,13 @@ def get_queue_stats(self) -> Dict[str, Dict[str, str]]: """ Returns a dict with stats about all the queues. The keys are the queue names, the values are dicts representing how many tasks are in a given - status ("queued", "active", "error" or "scheduled"). + status ("queued", "active", "error", "waiting", "completed" or "scheduled"). Example return value: { "default": { "queued": 1, "error": 2 } } """ - states = (QUEUED, ACTIVE, SCHEDULED, ERROR) + states = (QUEUED, ACTIVE, SCHEDULED, ERROR, WAITING, COMPLETED) pipeline = self.connection.pipeline() for state in states: @@ -699,6 +768,12 @@ def would_process_configured_queue(self, queue_name: str) -> bool: help="Maximum workers allowed to process a queue", type=int, ) +@click.option( + "-P", + "--max-parallel-workers", + help="Maximum parallel workers", + type=int, +) @click.option( "--store-tracebacks/--no-store-tracebacks", help="Store tracebacks with execution history", @@ -728,6 +803,7 @@ def run_worker( module: Optional[str] = None, exclude_queues: Optional[str] = None, max_workers_per_queue: Optional[int] = None, + max_parallel_workers: Optional[int] = None, store_tracebacks: Optional[bool] = None, executor: Optional[str] = "fork", exit_after: Optional[int] = None, @@ -755,6 +831,7 @@ def run_worker( module=module, exclude_queues=exclude_queues, max_workers_per_queue=max_workers_per_queue, + max_parallel_workers=max_parallel_workers, store_tracebacks=store_tracebacks, executor_class=executor_class, exit_after=exit_after_td, diff --git a/tasktiger/worker.py b/tasktiger/worker.py index 38d4fd1c..863394fa 100644 --- a/tasktiger/worker.py +++ b/tasktiger/worker.py @@ -31,6 +31,8 @@ ERROR, QUEUED, SCHEDULED, + WAITING, + COMPLETED, dotted_parts, gen_unique_id, import_attribute, @@ -665,6 +667,19 @@ def _execute_task_group( lock_ids.add(lock_id) locks.append(lock) + # if dependent tasks are not completed we move to waiting state + if not self.are_deps_completed(task): + task._move( + from_state=ACTIVE, + to_state=WAITING, + ) + log.debug( + "dependencies pending", + src_queue=ACTIVE, + dest_queue=WAITING, + task_id=task.id, + ) + continue ready_tasks.append(task) if not ready_tasks: @@ -687,6 +702,20 @@ def _execute_task_group( return success, ready_tasks + def are_deps_completed(self, task: Task) -> bool: + """ + True if all depend tasks are completed + """ + + if not task.depends: + return True + + for dep_task_id in task.depends: + exists = self.tiger.connection.zscore(self.tiger._key(COMPLETED, task.queue), dep_task_id) + if not exists: + return False + return True + def _finish_task_processing( self, queue: str, task: Task, success: bool, start_time: float ) -> None: @@ -712,11 +741,39 @@ def _finish_task_processing( def _mark_done() -> None: # Remove the task from active queue - task._move(from_state=ACTIVE) + task._move(from_state=ACTIVE, to_state=COMPLETED) log.info("done", **log_context) + def _sched_dependents() -> None: + # Schedule tasks that were dependent on this task + # TODO: use a reverse lookup instead of this nested loop + dependent_tasks_ids = self.connection.zrange( + self._key(WAITING, task.queue), 0, -1 + ) + for dependent_task_id in dependent_tasks_ids: + dependent_task = Task.from_id( + self.tiger, + queue=task.queue, + state=WAITING, + task_id=dependent_task_id, + ) + if not dependent_task.depends: + continue + for dep_task_id in dependent_task.depends: + if dep_task_id == task.id: + dependent_task._move(from_state=WAITING, to_state=SCHEDULED) + log.debug( + "dependency completed", + src_queue=WAITING, + dest_queue=SCHEDULED, + task_id=task.id, + dep_task_id=dep_task_id, + ) + break + if success: _mark_done() + _sched_dependents() else: should_retry = False should_log_error = True diff --git a/tests/test_base.py b/tests/test_base.py index df05ea1c..7d11dd1e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -94,7 +94,7 @@ def test_simple_task(self): Worker(self.tiger).run(once=True) self._ensure_queues(queued={"default": 0}) - assert not self.conn.exists("t:task:%s" % task["id"]) + assert self.conn.zscore("t:completed:default", task["id"]) @pytest.mark.skipif(sys.version_info < (3, 3), reason="__qualname__ unavailable") def test_staticmethod_task(self): @@ -105,7 +105,7 @@ def test_staticmethod_task(self): Worker(self.tiger).run(once=True) self._ensure_queues(queued={"default": 0}) - assert not self.conn.exists("t:task:%s" % task["id"]) + assert self.conn.zscore("t:completed:default", task["id"]) def test_task_delay(self): decorated_task.delay(1, 2, a=3, b=4) @@ -602,8 +602,6 @@ def test_retry_exception_2(self): Worker(self.tiger).run(once=True) self._ensure_queues() - pytest.raises(TaskNotFound, task.n_executions) - def test_retry_exception_3(self): task = self.tiger.delay(retry_task_3) self._ensure_queues(queued={"default": 1}) @@ -619,8 +617,6 @@ def test_retry_exception_3(self): Worker(self.tiger).run(once=True) self._ensure_queues() - pytest.raises(TaskNotFound, task.n_executions) - @pytest.mark.parametrize("count", [1, 3, 7]) def test_retry_executions_count(self, count): task = self.tiger.delay(exception_task, retry_method=fixed(DELAY, 20)) diff --git a/tests/test_logging.py b/tests/test_logging.py index 703d4d27..b1c938df 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -53,7 +53,7 @@ def test_structlog_processor(self): Worker(self.tiger).run(once=True) self._ensure_queues(queued={"foo_qux": 0}) - assert not self.conn.exists("t:task:%s" % task["id"]) + assert self.conn.zscore("t:completed:foo_qux", task["id"]) class TestSetupStructlog(BaseTestCase): diff --git a/tests/test_periodic.py b/tests/test_periodic.py index 699cbd92..f3f04133 100644 --- a/tests/test_periodic.py +++ b/tests/test_periodic.py @@ -310,7 +310,7 @@ def test_successful_execution_clears_executions_from_retries(self): # Ensure we cleared any previous executions. task = Task.from_id(tiger, "periodic", SCHEDULED, task_id, load_executions=10) - assert len(task.executions) == 0 + assert len(task.executions) == 1 def test_successful_execution_doesnt_clear_previous_errors(self): """ diff --git a/tests/test_workers.py b/tests/test_workers.py index a4c895c9..0a216d9d 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -23,7 +23,7 @@ wait_for_long_task, ) from .test_base import BaseTestCase -from .utils import external_worker +from .utils import external_parallel_worker, external_worker class TestMaxWorkers(BaseTestCase): @@ -258,3 +258,33 @@ def test_stop_heartbeat_thread_on_unhandled_exception(self, tiger, ensure_queues # handled by the executor, the task is still active until it times out # and gets requeued by another worker. ensure_queues(active={"default": 1}) + + +class TestMaxParallelWorkers(BaseTestCase): + """Max Parallel Worker Queue tests.""" + + def test_max_parallel_workers(self): + """Test Parallel Worker Queue.""" + + # Queue three tasks + for i in range(0, 3): + task = Task(self.tiger, long_task_ok, queue="a") + task.delay() + self._ensure_queues(queued={"a": 3}) + # self._ensure_queues() + + # # Start two parallel workers and wait until they start processing. + worker = Process( + target=external_parallel_worker, + kwargs={ + "worker_kwargs": {"queues": "a", "max_parallel_workers": 2}, + }, + ) + worker.start() + + # Wait for both tasks to start + wait_for_long_task() + wait_for_long_task() + + # Verify that 2 of them are active + self._ensure_queues(active={"a": 2}, queued={"a": 1}) diff --git a/tests/utils.py b/tests/utils.py index 50f8b174..0805af54 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -87,6 +87,22 @@ def external_worker(n=None, patch_config=None, worker_kwargs=None): tiger.connection.close() +def external_parallel_worker(n=None, patch_config=None, worker_kwargs=None): + """ + Runs a worker. To be used with multiprocessing.Pool.map. + """ + tiger = get_tiger() + + if patch_config: + tiger.config.update(patch_config) + + if worker_kwargs is None: + worker_kwargs = {} + + tiger.run_worker(**worker_kwargs) + + tiger.connection.close() + def sleep_until_next_second(): now = datetime.datetime.utcnow() time.sleep(1 - now.microsecond / 10.0**6)