diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py index 8bd612ef59..b6e72f24d8 100644 --- a/flytekit/core/array_node.py +++ b/flytekit/core/array_node.py @@ -41,6 +41,7 @@ def __init__( min_successes: Optional[int] = None, min_success_ratio: Optional[float] = None, metadata: Optional[_workflow_model.NodeMetadata] = None, + run_all_sub_nodes: bool = False, ): """ :param target: The target Flyte entity to map over @@ -52,11 +53,13 @@ def __init__( min_success_ratio :param min_success_ratio: The minimum ratio of successful executions. :param metadata: The metadata for the underlying node + :param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met """ from flytekit.remote import FlyteLaunchPlan self.target = target self._concurrency = concurrency + self._run_all_sub_nodes = run_all_sub_nodes self.id = target.name self._bindings = bindings or [] self.metadata = metadata @@ -226,6 +229,10 @@ def min_successes(self) -> Optional[int]: def concurrency(self) -> Optional[int]: return self._concurrency + @property + def run_all_sub_nodes(self) -> bool: + return self._run_all_sub_nodes + @property def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode: return self._execution_mode @@ -275,6 +282,7 @@ def array_node( concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, min_successes: Optional[int] = None, + run_all_sub_nodes: bool = False, ): """ ArrayNode implementation that maps over tasks and other Flyte entities @@ -287,6 +295,7 @@ def array_node( :param min_successes: The minimum number of successful executions. If set, this takes precedence over min_success_ratio :param min_success_ratio: The minimum ratio of successful executions + :param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met :return: A callable function that takes in keyword arguments and returns a Promise created by flyte_entity_call_handler """ @@ -300,6 +309,7 @@ def array_node( concurrency=concurrency, min_successes=min_successes, min_success_ratio=min_success_ratio, + run_all_sub_nodes=run_all_sub_nodes, ) return node diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 9df005ac33..95ffe1dbfc 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -42,6 +42,7 @@ def __init__( min_success_ratio: Optional[float] = None, bound_inputs: Optional[Set[str]] = None, bound_inputs_values: Optional[Dict[str, Any]] = None, + run_all_sub_nodes: bool = False, **kwargs, ): """ @@ -51,6 +52,7 @@ def __init__( :param min_success_ratio: The minimum ratio of successful executions :param bound_inputs: The set of inputs that should be bound to the map task :param bound_inputs_values: Inputs that are bound to the array node and will not be mapped over + :param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met :param kwargs: Additional keyword arguments to pass to the base class """ self._partial = None @@ -113,6 +115,7 @@ def __init__( self._concurrency: Optional[int] = concurrency self._min_successes: Optional[int] = min_successes self._min_success_ratio: Optional[float] = min_success_ratio + self._run_all_sub_nodes: bool = run_all_sub_nodes self._collection_interface = collection_interface self._execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE @@ -168,6 +171,10 @@ def min_successes(self) -> Optional[int]: def concurrency(self) -> Optional[int]: return self._concurrency + @property + def run_all_sub_nodes(self) -> bool: + return self._run_all_sub_nodes + @property def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]: return self._run_task @@ -385,6 +392,7 @@ def map_task( concurrency: Optional[int] = None, min_successes: Optional[int] = None, min_success_ratio: float = 1.0, + run_all_sub_nodes: bool = False, **kwargs, ): """ @@ -398,6 +406,7 @@ def map_task( array node will inherit parallelism from the workflow :param min_successes: The minimum number of successful executions :param min_success_ratio: The minimum ratio of successful executions + :param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met """ from flytekit.remote import FlyteLaunchPlan @@ -407,12 +416,14 @@ def map_task( concurrency=concurrency, min_successes=min_successes, min_success_ratio=min_success_ratio, + run_all_sub_nodes=run_all_sub_nodes, ) return array_node_map_task( task_function=target, concurrency=concurrency, min_successes=min_successes, min_success_ratio=min_success_ratio, + run_all_sub_nodes=run_all_sub_nodes, **kwargs, ) @@ -422,6 +433,7 @@ def array_node_map_task( concurrency: Optional[int] = None, # TODO why no min_successes? min_success_ratio: float = 1.0, + run_all_sub_nodes: bool = False, **kwargs, ): """Map task that uses the ``ArrayNode`` construct.. @@ -437,8 +449,15 @@ def array_node_map_task( array node will inherit parallelism from the workflow :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete successfully before terminating this task and marking it successful. + :param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met. """ - return ArrayNodeMapTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs) + return ArrayNodeMapTask( + task_function, + concurrency=concurrency, + min_success_ratio=min_success_ratio, + run_all_sub_nodes=run_all_sub_nodes, + **kwargs, + ) class ArrayNodeMapTaskResolver(tracker.TrackedInstance, TaskResolverMixin): diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 813bbce70a..4d394c7e29 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -396,6 +396,7 @@ def __init__( is_original_sub_node_interface=False, data_mode=None, bound_inputs=None, + run_all_sub_nodes: bool = False, ) -> None: """ TODO: docstring @@ -409,6 +410,7 @@ def __init__( self._is_original_sub_node_interface = is_original_sub_node_interface self._data_mode = data_mode self._bound_inputs = bound_inputs + self._run_all_sub_nodes = run_all_sub_nodes @property def node(self) -> "Node": @@ -424,6 +426,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode: is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface), data_mode=self._data_mode, bound_inputs=sorted(self._bound_inputs) if self._bound_inputs else None, + run_all_sub_nodes=self._run_all_sub_nodes, ) @classmethod @@ -433,6 +436,7 @@ def from_flyte_idl(cls, pb2_object) -> "ArrayNode": pb2_object.parallelism, pb2_object.min_successes, pb2_object.min_success_ratio, + run_all_sub_nodes=pb2_object.run_all_sub_nodes, ) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 886ee289c7..e75ea30b21 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -635,6 +635,7 @@ def get_serializable_array_node( is_original_sub_node_interface=array_node.is_original_sub_node_interface, data_mode=array_node.data_mode, bound_inputs=array_node.bound_inputs, + run_all_sub_nodes=array_node.run_all_sub_nodes, ) @@ -671,6 +672,7 @@ def get_serializable_array_node_map_task( execution_mode=entity.execution_mode, is_original_sub_node_interface=entity.is_original_sub_node_interface, bound_inputs=entity.bound_inputs, + run_all_sub_nodes=entity.run_all_sub_nodes, ) diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py index da3c0115f4..a817d1e9c0 100644 --- a/tests/flytekit/unit/core/test_array_node.py +++ b/tests/flytekit/unit/core/test_array_node.py @@ -378,3 +378,47 @@ def test_map_task_wrapper(): mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6], c=[7, 8, 9]) assert mapped_lp == [14, 96, 270] + + +def test_run_all_sub_nodes_default(): + node = array_node(lp, concurrency=10, min_success_ratio=0.9) + assert node.run_all_sub_nodes is False + + +def test_run_all_sub_nodes_set(): + node = array_node(lp, concurrency=10, min_success_ratio=0.9, run_all_sub_nodes=True) + assert node.run_all_sub_nodes is True + + +def test_run_all_sub_nodes_serialization(serialization_settings): + @workflow + def wf_run_all() -> typing.List[int]: + return map_task(lp, concurrency=10, min_success_ratio=0.9, run_all_sub_nodes=True)( + a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9] + ) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf_run_all) + + parent_node = wf_spec.template.nodes[0] + assert parent_node.array_node._run_all_sub_nodes is True + + pb = parent_node.array_node.to_flyte_idl() + assert pb.run_all_sub_nodes is True + + +def test_run_all_sub_nodes_serialization_default(serialization_settings): + @workflow + def wf_no_run_all() -> typing.List[int]: + return map_task(lp, concurrency=10, min_success_ratio=0.9)( + a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9] + ) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf_no_run_all) + + parent_node = wf_spec.template.nodes[0] + assert parent_node.array_node._run_all_sub_nodes is False + + pb = parent_node.array_node.to_flyte_idl() + assert pb.run_all_sub_nodes is False diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index a2f35424e4..15f562969e 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -578,6 +578,7 @@ def wf1(x: typing.List[int]): assert array_node.array_node._parallelism == 10 assert not array_node.array_node._is_original_sub_node_interface assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE + assert not array_node.array_node._run_all_sub_nodes task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.metadata.interruptible @@ -588,6 +589,67 @@ def wf1(x: typing.List[int]): assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE +def test_run_all_sub_nodes_default(): + @task + def t1(a: int) -> int: + return a + 1 + + mt = map_task(t1) + assert mt.run_all_sub_nodes is False + + +def test_run_all_sub_nodes_set(): + @task + def t1(a: int) -> int: + return a + 1 + + mt = map_task(t1, run_all_sub_nodes=True) + assert mt.run_all_sub_nodes is True + + +def test_run_all_sub_nodes_serialization(serialization_settings): + @task + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = map_task(t1, run_all_sub_nodes=True) + + @workflow + def wf(x: typing.List[int]): + return arraynode_maptask(a=x) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf) + + array_node = wf_spec.template.nodes[0] + assert array_node.array_node._run_all_sub_nodes is True + + # Verify it serializes to the protobuf correctly + pb = array_node.array_node.to_flyte_idl() + assert pb.run_all_sub_nodes is True + + +def test_run_all_sub_nodes_serialization_default(serialization_settings): + @task + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = map_task(t1) + + @workflow + def wf(x: typing.List[int]): + return arraynode_maptask(a=x) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf) + + array_node = wf_spec.template.nodes[0] + assert array_node.array_node._run_all_sub_nodes is False + + pb = array_node.array_node.to_flyte_idl() + assert pb.run_all_sub_nodes is False + + def test_serialization_extended_resources(serialization_settings): @task( accelerator=GPUAccelerator("test_gpu"),