|
9 | 9 | from tilebox.workflows import ExecutionContext, Task |
10 | 10 | from tilebox.workflows.cache import InMemoryCache, JobCache |
11 | 11 | from tilebox.workflows.client import Client |
12 | | -from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext |
| 12 | +from tilebox.workflows.data import JobState, ProgressIndicator, RunnerContext, TaskState |
13 | 13 | from tilebox.workflows.runner.task_runner import TaskRunner |
14 | 14 |
|
15 | 15 |
|
@@ -243,3 +243,79 @@ def test_runner_disallow_duplicate_task_identifiers() -> None: |
243 | 243 | ), |
244 | 244 | ): |
245 | 245 | runner.register(ExplicitIdentifierTaskV2) |
| 246 | + |
| 247 | + |
| 248 | +class OptionalSubbranch(Task): |
| 249 | + def execute(self, context: ExecutionContext) -> None: |
| 250 | + context.submit_subtask(OptionalSubtasks(False), optional=True) |
| 251 | + context.submit_subtask(SucceedingTask()) |
| 252 | + |
| 253 | + |
| 254 | +class OptionalSubtasks(Task): |
| 255 | + failing_task_optional: bool |
| 256 | + |
| 257 | + def execute(self, context: ExecutionContext) -> None: |
| 258 | + f = context.submit_subtask(FailingTask(), optional=self.failing_task_optional) |
| 259 | + context.submit_subtask(SucceedingTask(), depends_on=[f]) |
| 260 | + |
| 261 | + |
| 262 | +class FailingTask(Task): |
| 263 | + def execute(self, context: ExecutionContext) -> None: |
| 264 | + cache = context.job_cache # ty: ignore[unresolved-attribute] |
| 265 | + cache["failing_task"] = b"1" # to make sure it actually ran |
| 266 | + raise ValueError("This task always fails") |
| 267 | + |
| 268 | + |
| 269 | +class SucceedingTask(Task): |
| 270 | + def execute(self, context: ExecutionContext) -> None: |
| 271 | + cache = context.job_cache # ty: ignore[unresolved-attribute] |
| 272 | + cache["succeeding_task"] = b"1" # to make sure it actually ran |
| 273 | + |
| 274 | + |
| 275 | +def test_runner_optional_subbranch() -> None: |
| 276 | + client = replay_client("optional_subbranch.rpcs.bin") |
| 277 | + job_client = client.jobs() |
| 278 | + |
| 279 | + with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock: |
| 280 | + # we hardcode the trace parent for the job, which allows us to assert that every single outgoing request |
| 281 | + # matches exactly byte for byte |
| 282 | + get_trace_parent_mock.return_value = "00-42fe17a0cc6752adf16a5a326d37f51c-795dd6a3bc5a0b81-01" |
| 283 | + job = client.jobs().submit("optional-subbranch-test", OptionalSubbranch()) |
| 284 | + |
| 285 | + cache = InMemoryCache() |
| 286 | + runner = client.runner(tasks=[OptionalSubbranch, OptionalSubtasks, FailingTask, SucceedingTask], cache=cache) |
| 287 | + |
| 288 | + runner.run_all() |
| 289 | + job = job_client.find(job) # load current job state |
| 290 | + assert job.state == JobState.COMPLETED |
| 291 | + |
| 292 | + assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 3 |
| 293 | + assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1 |
| 294 | + assert job.execution_stats.tasks_by_state[TaskState.SKIPPED] == 1 |
| 295 | + |
| 296 | + assert cache.group(str(job.id))["failing_task"] == b"1" |
| 297 | + assert cache.group(str(job.id))["succeeding_task"] == b"1" |
| 298 | + |
| 299 | + |
| 300 | +def test_runner_optional_subtask() -> None: |
| 301 | + client = replay_client("optional_subtask.rpcs.bin") |
| 302 | + job_client = client.jobs() |
| 303 | + |
| 304 | + with patch("tilebox.workflows.jobs.client.get_trace_parent_of_current_span") as get_trace_parent_mock: |
| 305 | + # we hardcode the trace parent for the job, which allows us to assert that every single outgoing request |
| 306 | + # matches exactly byte for byte |
| 307 | + get_trace_parent_mock.return_value = "00-154ffe629cc5b746584825bfbb37963d-3ed10512af70309c-01" |
| 308 | + job = client.jobs().submit("optional-subtasks-test", OptionalSubtasks(True)) |
| 309 | + |
| 310 | + cache = InMemoryCache() |
| 311 | + runner = client.runner(tasks=[OptionalSubtasks, FailingTask, SucceedingTask], cache=cache) |
| 312 | + |
| 313 | + runner.run_all() |
| 314 | + job = job_client.find(job) # load current job state |
| 315 | + assert job.state == JobState.COMPLETED |
| 316 | + |
| 317 | + assert job.execution_stats.tasks_by_state[TaskState.COMPUTED] == 2 |
| 318 | + assert job.execution_stats.tasks_by_state[TaskState.FAILED_OPTIONAL] == 1 |
| 319 | + |
| 320 | + assert cache.group(str(job.id))["failing_task"] == b"1" |
| 321 | + assert cache.group(str(job.id))["succeeding_task"] == b"1" |
0 commit comments