Skip to content

Commit e3605c1

Browse files
Don't fail tasks on task runner requested shutdown (#30)
* Rename cancel job flag in TaskFailed request to was_workflow_error * Fix ty pre commit hook
1 parent c9def48 commit e3605c1

8 files changed

Lines changed: 107 additions & 24 deletions

File tree

.pre-commit-config.yaml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@ repos:
1616
- id: ruff-format
1717
- repo: local
1818
hooks:
19-
- id: ty-check
19+
- id: ty
2020
name: ty-check
21+
entry: uv run ty check
2122
language: python
22-
entry: ty check
23-
pass_filenames: false
24-
args: [--python=.venv/]
25-
additional_dependencies: [ty]
23+
types: [python]
24+
pass_filenames: true

tilebox-workflows/tests/runner/test_runner.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tilebox.workflows import ExecutionContext, Task
1010
from tilebox.workflows.cache import InMemoryCache, JobCache
1111
from tilebox.workflows.client import Client
12-
from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext
12+
from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext, TaskState
1313
from tilebox.workflows.runner.task_runner import TaskRunner
1414

1515

@@ -243,3 +243,79 @@ def test_runner_disallow_duplicate_task_identifiers() -> None:
243243
),
244244
):
245245
runner.register(ExplicitIdentifierTaskV2)
246+
247+
248+
class OptionalSubbranch(Task):
249+
def execute(self, context: ExecutionContext) -> None:
250+
context.submit_subtask(OptionalSubtasks(False), optional=True)
251+
context.submit_subtask(SucceedingTask())
252+
253+
254+
class OptionalSubtasks(Task):
255+
failing_task_optional: bool
256+
257+
def execute(self, context: ExecutionContext) -> None:
258+
f = context.submit_subtask(FailingTask(), optional=self.failing_task_optional)
259+
context.submit_subtask(SucceedingTask(), depends_on=[f])
260+
261+
262+
class FailingTask(Task):
263+
def execute(self, context: ExecutionContext) -> None:
264+
cache = context.job_cache # ty: ignore[unresolved-attribute]
265+
cache["failing_task"] = b"1" # to make sure it actually ran
266+
raise ValueError("This task always fails")
267+
268+
269+
class SucceedingTask(Task):
270+
def execute(self, context: ExecutionContext) -> None:
271+
cache = context.job_cache # ty: ignore[unresolved-attribute]
272+
cache["succeeding_task"] = b"1" # to make sure it actually ran
273+
274+
275+
def test_runner_optional_subbranch() -> None:
276+
client = replay_client("optional_subbranch.rpcs.bin")
277+
job_client = client.jobs()
278+
279+
with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock:
280+
# we hardcode the trace parent for the job, which allows us to assert that every single outgoing request
281+
# matches exactly byte for byte
282+
get_trace_parent_mock.return_value = "00-42fe17a0cc6752adf16a5a326d37f51c-795dd6a3bc5a0b81-01"
283+
job = client.jobs().submit("optional-subbranch-test", OptionalSubbranch())
284+
285+
cache = InMemoryCache()
286+
runner = client.runner(tasks=[OptionalSubbranch, OptionalSubtasks, FailingTask, SucceedingTask], cache=cache)
287+
288+
runner.run_all()
289+
job = job_client.find(job) # load current job state
290+
assert job.state == JobState.COMPLETED
291+
292+
assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 3
293+
assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1
294+
assert job.execution_stats.tasks_by_state[TaskState.SKIPPED] == 1
295+
296+
assert cache.group(str(job.id))["failing_task"] == b"1"
297+
assert cache.group(str(job.id))["succeeding_task"] == b"1"
298+
299+
300+
def test_runner_optional_subtask() -> None:
301+
client = replay_client("optional_subtask.rpcs.bin")
302+
job_client = client.jobs()
303+
304+
with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock:
305+
# we hardcode the trace parent for the job, which allows us to assert that every single outgoing request
306+
# matches exactly byte for byte
307+
get_trace_parent_mock.return_value = "00-154ffe629cc5b746584825bfbb37963d-3ed10512af70309c-01"
308+
job = client.jobs().submit("optional-subtasks-test", OptionalSubtasks(True))
309+
310+
cache = InMemoryCache()
311+
runner = client.runner(tasks=[OptionalSubtasks, FailingTask, SucceedingTask], cache=cache)
312+
313+
runner.run_all()
314+
job = job_client.find(job) # load current job state
315+
assert job.state == JobState.COMPLETED
316+
317+
assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 2
318+
assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1
319+
320+
assert cache.group(str(job.id))["failing_task"] == b"1"
321+
assert cache.group(str(job.id))["succeeding_task"] == b"1"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:cac0cd1cd31ec5c62949704159a113d2f69a249c35997eac113a2ac695d196d5
3+
size 4430
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:3eee76fab8b6018fbba6e0332a8a5a193a1095ceab3e4a0b2e2b586a5d520f9b
3+
size 3526

tilebox-workflows/tilebox/workflows/runner/task_runner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _external_interrupt_handler(self, signum: int, frame: FrameType | None) -> N
232232
self._service.task_failed(
233233
self._task,
234234
RunnerShutdown("Task was interrupted"),
235-
cancel_job=False,
235+
was_workflow_error=False,
236236
progress_updates=progress,
237237
)
238238

@@ -441,9 +441,11 @@ def _execute(self, task: Task, shutdown_context: _GracefulShutdown) -> Task | Id
441441
self.logger.exception(f"Task {task_repr} failed!")
442442

443443
task_failed_retry = _retry_backoff(self._service.task_failed, stop=shutdown_context.stop_if_shutting_down())
444-
cancel_job = True
445-
progress_updates = _finalize_mutable_progress_trackers(context._progress_indicators) # noqa: SLF001
446-
task_failed_retry(task, e, cancel_job, progress_updates)
444+
was_workflow_error = True
445+
progress_updates: list[ProgressIndicator] = _finalize_mutable_progress_trackers(
446+
context._progress_indicators # noqa: SLF001
447+
)
448+
task_failed_retry(task, e, was_workflow_error, progress_updates)
447449

448450
return None
449451

tilebox-workflows/tilebox/workflows/runner/task_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ def next_task(self, task_to_run: NextTaskToRun | None, computed_task: ComputedTa
4949
return None
5050

5151
def task_failed(
52-
self, task: Task, error: Exception, cancel_job: bool, progress_updates: list[ProgressIndicator]
52+
self, task: Task, error: Exception, was_workflow_error: bool, progress_updates: list[ProgressIndicator]
5353
) -> None:
5454
# job ouptut is limited to 1KB, so truncate the error message if necessary
5555
error_message = repr(error)[: (1024 - len(task.display or "None") - 1)]
5656
display = f"{task.display}" if error_message == "" else f"{task.display}\n{error_message}"
5757

5858
request = TaskFailedRequest(
5959
task_id=uuid_to_uuid_message(task.id),
60-
cancel_job=cancel_job,
60+
was_workflow_error=was_workflow_error,
6161
display=display,
6262
progress_updates=[progress.to_message() for progress in progress_updates],
6363
)

tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.py

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tilebox-workflows/tilebox/workflows/workflows/v1/task_pb2.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ class NextTaskResponse(_message.Message):
5555
def __init__(self, next_task: _Optional[_Union[_core_pb2.Task, _Mapping]] = ..., idling: _Optional[_Union[IdlingResponse, _Mapping]] = ...) -> None: ...
5656

5757
class TaskFailedRequest(_message.Message):
58-
__slots__ = ("task_id", "display", "cancel_job", "progress_updates")
58+
__slots__ = ("task_id", "display", "was_workflow_error", "progress_updates")
5959
TASK_ID_FIELD_NUMBER: _ClassVar[int]
6060
DISPLAY_FIELD_NUMBER: _ClassVar[int]
61-
CANCEL_JOB_FIELD_NUMBER: _ClassVar[int]
61+
WAS_WORKFLOW_ERROR_FIELD_NUMBER: _ClassVar[int]
6262
PROGRESS_UPDATES_FIELD_NUMBER: _ClassVar[int]
6363
task_id: _id_pb2.ID
6464
display: str
65-
cancel_job: bool
65+
was_workflow_error: bool
6666
progress_updates: _containers.RepeatedCompositeFieldContainer[_core_pb2.Progress]
67-
def __init__(self, task_id: _Optional[_Union[_id_pb2.ID, _Mapping]] = ..., display: _Optional[str] = ..., cancel_job: bool = ..., progress_updates: _Optional[_Iterable[_Union[_core_pb2.Progress, _Mapping]]] = ...) -> None: ...
67+
def __init__(self, task_id: _Optional[_Union[_id_pb2.ID, _Mapping]] = ..., display: _Optional[str] = ..., was_workflow_error: bool = ..., progress_updates: _Optional[_Iterable[_Union[_core_pb2.Progress, _Mapping]]] = ...) -> None: ...
6868

6969
class TaskStateResponse(_message.Message):
7070
__slots__ = ("state",)

0 commit comments

Comments
 (0)