Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
metadata: Optional[_workflow_model.NodeMetadata] = None,
run_all_sub_nodes: bool = False,
):
"""
:param target: The target Flyte entity to map over
Expand All @@ -52,11 +53,13 @@ def __init__(
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions.
:param metadata: The metadata for the underlying node
:param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._run_all_sub_nodes = run_all_sub_nodes
self.id = target.name
self._bindings = bindings or []
self.metadata = metadata
Expand Down Expand Up @@ -226,6 +229,10 @@ def min_successes(self) -> Optional[int]:
def concurrency(self) -> Optional[int]:
return self._concurrency

@property
def run_all_sub_nodes(self) -> bool:
    """Return the ``run_all_sub_nodes`` flag given at construction time.

    Per the constructor documentation, when this is True all sub-nodes run
    to completion even after the failure threshold is met.
    """
    return self._run_all_sub_nodes

@property
def execution_mode(self) -> _core_workflow.ArrayNode.ExecutionMode:
return self._execution_mode
Expand Down Expand Up @@ -275,6 +282,7 @@ def array_node(
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
run_all_sub_nodes: bool = False,
):
"""
ArrayNode implementation that maps over tasks and other Flyte entities
Expand All @@ -287,6 +295,7 @@ def array_node(
:param min_successes: The minimum number of successful executions. If set, this takes precedence over
min_success_ratio
:param min_success_ratio: The minimum ratio of successful executions
:param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met
:return: A callable function that takes in keyword arguments and returns a Promise created by
flyte_entity_call_handler
"""
Expand All @@ -300,6 +309,7 @@ def array_node(
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
run_all_sub_nodes=run_all_sub_nodes,
)

return node
21 changes: 20 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
bound_inputs_values: Optional[Dict[str, Any]] = None,
run_all_sub_nodes: bool = False,
**kwargs,
):
"""
Expand All @@ -51,6 +52,7 @@ def __init__(
:param min_success_ratio: The minimum ratio of successful executions
:param bound_inputs: The set of inputs that should be bound to the map task
:param bound_inputs_values: Inputs that are bound to the array node and will not be mapped over
:param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met
:param kwargs: Additional keyword arguments to pass to the base class
"""
self._partial = None
Expand Down Expand Up @@ -113,6 +115,7 @@ def __init__(
self._concurrency: Optional[int] = concurrency
self._min_successes: Optional[int] = min_successes
self._min_success_ratio: Optional[float] = min_success_ratio
self._run_all_sub_nodes: bool = run_all_sub_nodes
self._collection_interface = collection_interface

self._execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE
Expand Down Expand Up @@ -168,6 +171,10 @@ def min_successes(self) -> Optional[int]:
def concurrency(self) -> Optional[int]:
return self._concurrency

@property
def run_all_sub_nodes(self) -> bool:
    """Return the ``run_all_sub_nodes`` flag stored on this map task.

    Per the constructor documentation, when this is True all sub-nodes run
    to completion even after the failure threshold is met.
    """
    return self._run_all_sub_nodes

@property
def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]:
return self._run_task
Expand Down Expand Up @@ -385,6 +392,7 @@ def map_task(
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
run_all_sub_nodes: bool = False,
**kwargs,
):
"""
Expand All @@ -398,6 +406,7 @@ def map_task(
array node will inherit parallelism from the workflow
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
:param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met
"""
from flytekit.remote import FlyteLaunchPlan

Expand All @@ -407,12 +416,14 @@ def map_task(
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
run_all_sub_nodes=run_all_sub_nodes,
)
return array_node_map_task(
task_function=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
run_all_sub_nodes=run_all_sub_nodes,
**kwargs,
)

Expand All @@ -422,6 +433,7 @@ def array_node_map_task(
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
run_all_sub_nodes: bool = False,
**kwargs,
):
"""Map task that uses the ``ArrayNode`` construct..
Expand All @@ -437,8 +449,15 @@ def array_node_map_task(
array node will inherit parallelism from the workflow
:param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete
successfully before terminating this task and marking it successful.
:param run_all_sub_nodes: If True, all sub-nodes will run to completion even after the failure threshold is met.
"""
return ArrayNodeMapTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs)
return ArrayNodeMapTask(
task_function,
concurrency=concurrency,
min_success_ratio=min_success_ratio,
run_all_sub_nodes=run_all_sub_nodes,
**kwargs,
)


class ArrayNodeMapTaskResolver(tracker.TrackedInstance, TaskResolverMixin):
Expand Down
4 changes: 4 additions & 0 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def __init__(
is_original_sub_node_interface=False,
data_mode=None,
bound_inputs=None,
run_all_sub_nodes: bool = False,
) -> None:
"""
TODO: docstring
Expand All @@ -409,6 +410,7 @@ def __init__(
self._is_original_sub_node_interface = is_original_sub_node_interface
self._data_mode = data_mode
self._bound_inputs = bound_inputs
self._run_all_sub_nodes = run_all_sub_nodes

@property
def node(self) -> "Node":
Expand All @@ -424,6 +426,7 @@ def to_flyte_idl(self) -> _core_workflow.ArrayNode:
is_original_sub_node_interface=BoolValue(value=self._is_original_sub_node_interface),
data_mode=self._data_mode,
bound_inputs=sorted(self._bound_inputs) if self._bound_inputs else None,
run_all_sub_nodes=self._run_all_sub_nodes,
)

@classmethod
Expand All @@ -433,6 +436,7 @@ def from_flyte_idl(cls, pb2_object) -> "ArrayNode":
pb2_object.parallelism,
pb2_object.min_successes,
pb2_object.min_success_ratio,
run_all_sub_nodes=pb2_object.run_all_sub_nodes,
)


Expand Down
2 changes: 2 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ def get_serializable_array_node(
is_original_sub_node_interface=array_node.is_original_sub_node_interface,
data_mode=array_node.data_mode,
bound_inputs=array_node.bound_inputs,
run_all_sub_nodes=array_node.run_all_sub_nodes,
)


Expand Down Expand Up @@ -671,6 +672,7 @@ def get_serializable_array_node_map_task(
execution_mode=entity.execution_mode,
is_original_sub_node_interface=entity.is_original_sub_node_interface,
bound_inputs=entity.bound_inputs,
run_all_sub_nodes=entity.run_all_sub_nodes,
)


Expand Down
44 changes: 44 additions & 0 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,47 @@ def test_map_task_wrapper():

mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6], c=[7, 8, 9])
assert mapped_lp == [14, 96, 270]


def test_run_all_sub_nodes_default():
    """An array node created without the flag reports ``run_all_sub_nodes`` as False."""
    default_node = array_node(lp, concurrency=10, min_success_ratio=0.9)
    flag = default_node.run_all_sub_nodes
    assert flag is False


def test_run_all_sub_nodes_set():
    """Passing ``run_all_sub_nodes=True`` to ``array_node`` is exposed on the node."""
    flagged_node = array_node(
        lp,
        concurrency=10,
        min_success_ratio=0.9,
        run_all_sub_nodes=True,
    )
    flag = flagged_node.run_all_sub_nodes
    assert flag is True


def test_run_all_sub_nodes_serialization(serialization_settings):
    """``run_all_sub_nodes=True`` survives serialization to the model and to the protobuf."""

    @workflow
    def wf_run_all() -> typing.List[int]:
        return map_task(lp, concurrency=10, min_success_ratio=0.9, run_all_sub_nodes=True)(
            a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]
        )

    serialized_wf = get_serializable(OrderedDict(), serialization_settings, wf_run_all)

    # The workflow has a single node: the array node wrapping the launch plan.
    array_model = serialized_wf.template.nodes[0].array_node
    assert array_model._run_all_sub_nodes is True

    idl_node = array_model.to_flyte_idl()
    assert idl_node.run_all_sub_nodes is True


def test_run_all_sub_nodes_serialization_default(serialization_settings):
    """Without the flag, both the serialized model and the protobuf report False."""

    @workflow
    def wf_no_run_all() -> typing.List[int]:
        return map_task(lp, concurrency=10, min_success_ratio=0.9)(
            a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]
        )

    serialized_wf = get_serializable(OrderedDict(), serialization_settings, wf_no_run_all)

    # The workflow has a single node: the array node wrapping the launch plan.
    array_model = serialized_wf.template.nodes[0].array_node
    assert array_model._run_all_sub_nodes is False

    idl_node = array_model.to_flyte_idl()
    assert idl_node.run_all_sub_nodes is False
62 changes: 62 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ def wf1(x: typing.List[int]):
assert array_node.array_node._parallelism == 10
assert not array_node.array_node._is_original_sub_node_interface
assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE
assert not array_node.array_node._run_all_sub_nodes
task_spec = od[arraynode_maptask]
assert task_spec.template.metadata.retries.retries == 2
assert task_spec.template.metadata.interruptible
Expand All @@ -588,6 +589,67 @@ def wf1(x: typing.List[int]):
assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE


def test_run_all_sub_nodes_default():
    """A map task created without the flag reports ``run_all_sub_nodes`` as False."""

    @task
    def t1(a: int) -> int:
        return a + 1

    mapped = map_task(t1)
    flag = mapped.run_all_sub_nodes
    assert flag is False


def test_run_all_sub_nodes_set():
    """Passing ``run_all_sub_nodes=True`` to ``map_task`` is exposed on the map task."""

    @task
    def t1(a: int) -> int:
        return a + 1

    mapped = map_task(t1, run_all_sub_nodes=True)
    flag = mapped.run_all_sub_nodes
    assert flag is True


def test_run_all_sub_nodes_serialization(serialization_settings):
    """A map task built with ``run_all_sub_nodes=True`` keeps the flag through serialization."""

    @task
    def t1(a: int) -> int:
        return a + 1

    arraynode_maptask = map_task(t1, run_all_sub_nodes=True)

    @workflow
    def wf(x: typing.List[int]):
        return arraynode_maptask(a=x)

    serialized_wf = get_serializable(OrderedDict(), serialization_settings, wf)

    # The only node in the workflow is the array node produced by the map task.
    array_model = serialized_wf.template.nodes[0].array_node
    assert array_model._run_all_sub_nodes is True

    # The flag must also be present on the generated protobuf message.
    idl_node = array_model.to_flyte_idl()
    assert idl_node.run_all_sub_nodes is True


def test_run_all_sub_nodes_serialization_default(serialization_settings):
    """A map task built without the flag serializes with ``run_all_sub_nodes`` False."""

    @task
    def t1(a: int) -> int:
        return a + 1

    arraynode_maptask = map_task(t1)

    @workflow
    def wf(x: typing.List[int]):
        return arraynode_maptask(a=x)

    serialized_wf = get_serializable(OrderedDict(), serialization_settings, wf)

    # The only node in the workflow is the array node produced by the map task.
    array_model = serialized_wf.template.nodes[0].array_node
    assert array_model._run_all_sub_nodes is False

    idl_node = array_model.to_flyte_idl()
    assert idl_node.run_all_sub_nodes is False


def test_serialization_extended_resources(serialization_settings):
@task(
accelerator=GPUAccelerator("test_gpu"),
Expand Down
Loading