From 59d0aec7d78eecaf13aa4f7e0518bd955452060f Mon Sep 17 00:00:00 2001 From: redartera Date: Tue, 15 Apr 2025 12:40:12 +0000 Subject: [PATCH 1/6] add support for default task resource overrides Signed-off-by: redartera --- flytekit/clis/sdk_in_container/register.py | 14 ++- flytekit/clis/sdk_in_container/run.py | 15 +++ flytekit/configuration/__init__.py | 7 ++ flytekit/core/python_auto_container.py | 11 ++ flytekit/interaction/click_types.py | 35 ++++++ flytekit/remote/remote.py | 4 + flytekit/tools/repo.py | 3 + .../integration/remote/test_remote.py | 98 +++++++++++++++++ .../unit/core/test_context_manager.py | 4 +- .../flytekit/unit/core/test_serialization.py | 101 +++++++++++++++++- .../unit/interaction/test_click_types.py | 22 ++++ 11 files changed, 311 insertions(+), 3 deletions(-) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 57a9b58448..422f07b8bd 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -15,7 +15,8 @@ from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.constants import CopyFileDetection -from flytekit.interaction.click_types import key_value_callback +from flytekit.core.resources import ResourceSpec +from flytekit.interaction.click_types import key_value_callback, resource_spec_callback from flytekit.loggers import logger from flytekit.tools import repo @@ -134,6 +135,15 @@ callback=key_value_callback, help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) +@click.option( + "--default-resources", + required=False, + type=str, + callback=resource_spec_callback, + help="Override default task resource requests and limits for tasks that have no statically defined resource request and limit. " + """Example usage: --default-resources 'cpu=1;mem=2Gi;gpu=1' for requests only or """ + """--default-resources 'cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1' to specify both requests and limits""", +) @click.option( "--skip-errors", "--skip-error", @@ -161,6 +171,7 @@ def register( dry_run: bool, activate_launchplans: bool, env: typing.Optional[typing.Dict[str, str]], + default_resources: typing.Optional[ResourceSpec], skip_errors: bool, ): """ @@ -225,6 +236,7 @@ def register( package_or_module=package_or_module, remote=remote, env=env, + default_resources=default_resources, dry_run=dry_run, activate_launchplans=activate_launchplans, skip_errors=skip_errors, diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 7a08ef31af..d1b59cdfdf 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -43,6 +43,7 @@ from flytekit.core.artifact import ArtifactQuery from flytekit.core.base_task import PythonTask from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.resources import ResourceSpec from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException @@ -51,6 +52,7 @@ FlyteLiteralConverter, key_value_callback, labels_callback, + resource_spec_callback, ) from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger @@ -197,6 +199,18 @@ class RunLevelParams(PyFlyteParams): help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) ) + default_resources: typing.Optional[ResourceSpec] = make_click_option_field( + click.Option( + param_decls=["--default-resources"], + required=False, + show_default=True, + type=str, + callback=resource_spec_callback, + help="During fast registration, will override default task resource requests and limits for tasks that have no statically defined resource request and limit. " + """Example usage: --default-resources 'cpu=1;mem=2Gi;gpu=1' for requests only or """ + """--default-resources 'cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1' to specify both requests and limits""", + ) + ) tags: typing.List[str] = make_click_option_field( click.Option( param_decls=["--tags", "--tag"], @@ -745,6 +759,7 @@ def _run(*args, **kwargs): source_path=run_level_params.computed_params.project_root, module_name=run_level_params.computed_params.module, fast_package_options=fast_package_options, + default_resources=run_level_params.default_resources, ) run_remote( diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index bee6feb8be..76878c08e1 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -148,6 +148,7 @@ from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.core.resources import ResourceSpec from flytekit.image_spec import ImageSpec from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.loggers import logger @@ -824,6 +825,8 @@ class SerializationSettings(DataClassJsonMixin): version (str): The version (if any) with which to register entities under. image_config (ImageConfig): The image config used to define task container images. env (Optional[Dict[str, str]]): Environment variables injected into task container definitions. + default_resources (Optional[ResourceSpec]): The resources to request for the task - this is useful + if users need to override the default resource spec of an entity at registration time. flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit virtualenv at serialization time won't match the in-container value at execution time. This optional value is used to provide the in-container virtualenv path @@ -842,6 +845,7 @@ class SerializationSettings(DataClassJsonMixin): domain: typing.Optional[str] = None version: typing.Optional[str] = None env: Optional[Dict[str, str]] = None + default_resources: Optional[ResourceSpec] = None git_repo: Optional[str] = None python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None @@ -916,6 +920,7 @@ def new_builder(self) -> Builder: version=self.version, image_config=self.image_config, env=self.env.copy() if self.env else None, + default_resources=self.default_resources, git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, @@ -967,6 +972,7 @@ class Builder(object): version: str image_config: ImageConfig env: Optional[Dict[str, str]] = None + default_resources: Optional[ResourceSpec] = None git_repo: Optional[str] = None flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None @@ -984,6 +990,7 @@ def build(self) -> SerializationSettings: version=self.version, image_config=self.image_config, env=self.env, + default_resources=self.default_resources, git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index aa0327299b..b439ed1887 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -231,6 +231,17 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain for elem in (settings.env, self.environment): if elem: env.update(elem) + # Override the task's resource spec if it was not set statically in the task definition + + def _resources_unspecified(resources: ResourceSpec) -> bool: + return resources == ResourceSpec( + requests=Resources(), + limits=Resources(), + ) + + if isinstance(settings.default_resources, ResourceSpec) and _resources_unspecified(self.resources): + self._resources = settings.default_resources + return _get_container_definition( image=self.get_image(settings), resource_spec=self.resources, diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index f415367ebc..d5f7278d82 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -20,6 +20,7 @@ from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset from flytekit.core.artifact import ArtifactQuery from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.type_engine import TypeEngine from flytekit.models.types import SimpleType from flytekit.remote.remote_fs import FlytePathResolver @@ -81,6 +82,40 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi return result +def resource_spec_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[ResourceSpec]: + """ + Callback for click to parse a resource spec. + """ + if not value: + return None + + def _extract_pair(s: str) -> typing.Optional[typing.Tuple[str, str]]: + """Can extract the pair of values "0.5" and "1" from the string '(0.5,1)'""" + vals = s.strip("() ").split(",") + if len(vals) != 2: + return None + return vals[0].strip(), vals[1].strip() + + items = value.split(";") + _allowed_keys = Resources.__annotations__.keys() + result = {} + for item in items: + kv_split = item.split("=") + if len(kv_split) != 2: + raise click.BadParameter( + f"Expected semicolon separated key-value pairs of the form 'key1=value1;key2=value2;...', got '{item}'" + ) + k = kv_split[0].strip() + v = kv_split[1].strip() + if k not in _allowed_keys: + raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, got '{k}'") + if k in result: + raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, got '{k}' multiple times") + result[k.strip()] = _extract_pair(v) or v + + return ResourceSpec.from_multiple_resource(Resources(**result)) + + class DirParamType(click.ParamType): name = "directory path" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 6c2b065792..fea7b4b79f 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -56,6 +56,7 @@ ) from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec +from flytekit.core.resources import ResourceSpec from flytekit.core.task import ReferenceTask from flytekit.core.tracker import extract_task_module from flytekit.core.type_engine import LiteralsResolver, TypeEngine, strict_type_hint_matching @@ -1310,6 +1311,7 @@ def register_script( source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + default_resources: typing.Optional[ResourceSpec] = None, fast_package_options: typing.Optional[FastPackageOptions] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan, ReferenceEntity]: """ @@ -1326,6 +1328,7 @@ def register_script( :param source_path: The root of the project path :param module_name: the name of the module :param envs: Environment variables to be passed to the serialization + :param default_resources: Default resources to be passed to the serialization. These override the resource spec for any tasks that have no statically defined resource requests and limits. :param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False. :return: """ @@ -1364,6 +1367,7 @@ def register_script( image_config=image_config, git_repo=_get_git_repo_url(source_path), env=envs, + default_resources=default_resources, fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index e2e46f49d3..bb73cd9398 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -12,6 +12,7 @@ from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.constants import CopyFileDetection from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.resources import ResourceSpec from flytekit.loggers import logger from flytekit.models import launch_plan, task from flytekit.models.core.identifier import Identifier @@ -251,6 +252,7 @@ def register( remote: FlyteRemote, copy_style: CopyFileDetection, env: typing.Optional[typing.Dict[str, str]], + default_resources: typing.Optional[ResourceSpec], dry_run: bool = False, activate_launchplans: bool = False, skip_errors: bool = False, @@ -273,6 +275,7 @@ def register( image_config=image_config, fast_serialization_settings=None, # should probably add incomplete fast settings env=env, + default_resources=default_resources, ) if not version and copy_style == CopyFileDetection.NO_COPY: diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index c5e908beaa..3decba7928 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -17,12 +17,15 @@ import pytest from unittest import mock from dataclasses import dataclass +import random +import string from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow from flytekit.configuration import Config, ImageConfig, SerializationSettings from flytekit.core.launch_plan import reference_launch_plan from flytekit.core.task import reference_task from flytekit.core.workflow import reference_workflow +from flytekit.models import task as task_models from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.remote.remote import FlyteRemote @@ -1170,3 +1173,98 @@ def test_register_wf_twice(register): ] ) assert out.returncode == 0 + + +def test_register_wf_with_default_resources_override(register): + # Save the version here to retrieve the created task later + version = str(uuid.uuid4()) + # Register the workflow with overridden default resources + out = subprocess.run( + [ + "pyflyte", + "--verbose", + "-c", + CONFIG, + "register", + "--default-resources", + "cpu=1300m;mem=1100Mi", + "--image", + IMAGE, + "--project", + PROJECT, + "--domain", + DOMAIN, + "--version", + version, + MODULE_PATH / "hello_world.py", + ] + ) + assert out.returncode == 0 + + # Retrieve the created task + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) + assert task.template.container is not None + assert task.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.CPU, + value="1300m", + ), + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.MEMORY, + value="1100Mi", + ), + ], + limits=[], + ) + + +def test_run_wf_with_default_resources_override(register): + # Save the execution id here to retrieve the created execution later + prefix = random.choice(string.ascii_lowercase) + short_random_part = uuid.uuid4().hex[:8] + execution_id = f"{prefix}{short_random_part}" + # Register the workflow with overridden default resources + out = subprocess.run( + [ + "pyflyte", + "--verbose", + "-c", + CONFIG, + "run", + "--remote", + "--default-resources", + "cpu=500m;mem=1Gi", + "--project", + PROJECT, + "--domain", + DOMAIN, + "--name", + execution_id, + MODULE_PATH / "hello_world.py", + "my_wf" + ] + ) + assert out.returncode == 0 + + # Retrieve the created task + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution) + version = execution.spec.launch_plan.version + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) + assert task.template.container is not None + assert task.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.CPU, + value="500m", + ), + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.MEMORY, + value="1Gi", + ), + ], + limits=[], + ) diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index b2702c0a75..8fde709979 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -19,6 +19,7 @@ ) from flytekit.core import mock_stats, context_manager from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, SecretsManager +from flytekit.core.resources import ResourceSpec, Resources from flytekit.models.core import identifier as id_models @@ -301,6 +302,7 @@ def test_serialization_settings_transport(): domain="domain", version="version", env={"hello": "blah"}, + default_resources=ResourceSpec(requests=Resources(cpu="1", mem="2Gi"), limits=Resources(cpu="1", mem="2Gi")), image_config=ImageConfig( default_image=default_img, images=[default_img], @@ -322,7 +324,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 408 + assert len(tp) == 480 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 833a6d74ab..f112685080 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -1,6 +1,7 @@ import re import os import typing +import dataclasses from collections import OrderedDict import mock @@ -11,7 +12,10 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional from flytekit.core.python_auto_container import get_registerable_container_image -from flytekit.core.task import task +from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.array_node_map_task import map_task +from flytekit.core.task import eager, task from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion, FlyteMissingTypeException from flytekit.image_spec.image_spec import ImageBuildEngine @@ -28,6 +32,7 @@ Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType +from flytekit.models import task as task_models from flytekit.tools.translator import get_serializable from flytekit.types.error.error import FlyteError @@ -1165,3 +1170,97 @@ def t1(a: int) -> int: with pytest.raises(AssertionError, match="Got multiple values for argument"): t1(1, a=2) + + +def test_default_resources_override_resourceless_tasks(): + """Tests that default resources specified in serialization settings affect tasks where no resources are specified.""" + + _settings = dataclasses.replace( + serialization_settings, + default_resources=ResourceSpec(requests=Resources(cpu="2", mem="4Gi"), limits=Resources(cpu="2", mem="4Gi")) + ) + + expected_default_resources = task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="2"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="4Gi"), + ], + limits=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="2"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="4Gi"), + ], + ) + + + # We test against the 4 fundamental constructs of tasks: @task, @dynamic, @eager, and map_task + + @task + def t1(a: int) -> int: + return a + t1_spec = get_serializable(OrderedDict(), _settings, t1) + assert t1_spec.template.container.resources == expected_default_resources + + @dynamic + def t1_dynamic(a: int) -> int: + return a + t1_dynamic_spec = get_serializable(OrderedDict(), _settings, t1_dynamic) + assert t1_dynamic_spec.template.container.resources == expected_default_resources + + @eager + async def t1_eager(a: int) -> int: + return a + + t1_eager_spec = get_serializable(OrderedDict(), _settings, t1_eager) + assert t1_eager_spec.template.container.resources == expected_default_resources + + t1_map_task = map_task(t1) + t1_map_task_spec = get_serializable(OrderedDict(), _settings, t1_map_task) + assert t1_map_task_spec.template.container.resources == expected_default_resources + + +def test_default_resources_do_not_overriden_tasks_with_explicit_resources(): + """Tests that default resources specified in serialization settings do not override resources specified in task decorators.""" + + _settings = dataclasses.replace( + serialization_settings, + default_resources=ResourceSpec(requests=Resources(cpu="2", mem="4Gi"), limits=Resources(cpu="2", mem="4Gi")) + ) + + @task(requests=Resources(cpu="1", mem="2Gi")) + def t1(a: int) -> int: + return a + + t1_spec = get_serializable(OrderedDict(), _settings, t1) + assert t1_spec.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="1"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="2Gi"), + ], + limits=[] + ) + + @dynamic(limits=Resources(cpu="1", mem="2Gi")) + def t1_dynamic(a: int) -> int: + return a + + t1_dynamic_spec = get_serializable(OrderedDict(), _settings, t1_dynamic) + assert t1_dynamic_spec.template.container.resources == task_models.Resources( + requests=[], + limits=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="1"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="2Gi"), + ] + ) + + @eager(requests=Resources(cpu="1", mem="4Gi")) + async def t1_eager(a: int) -> int: + return a + + t1_eager_spec = get_serializable(OrderedDict(), _settings, t1_eager) + assert t1_eager_spec.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="1"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="4Gi"), + ], + limits=[] + ) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 33e0bbf78f..92c46946d2 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -15,6 +15,7 @@ from flytekit import FlyteContextManager from flytekit.core.artifact import Artifact +from flytekit.core.resources import ResourceSpec, Resources from flytekit.core.type_engine import TypeEngine from flytekit.interaction.click_types import ( DateTimeType, @@ -28,6 +29,7 @@ StructuredDatasetParamType, UnionParamType, key_value_callback, + resource_spec_callback, ) dummy_param = click.Option(["--dummy"], type=click.STRING, default="dummy") @@ -236,6 +238,26 @@ def test_key_value_callback(): key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"]) +def test_resource_spec_callback(): + ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + assert resource_spec_callback(ctx, "a", None) is None + assert resource_spec_callback(ctx, "a", "cpu=1;mem=2Gi") == ( + ResourceSpec(requests=Resources(cpu="1", mem="2Gi"), limits=Resources()) + ) + assert resource_spec_callback(ctx, "a", "cpu=1;mem=2Gi;gpu=1") == ( + ResourceSpec(requests=Resources(cpu="1", mem="2Gi", gpu="1"), limits=Resources()) + ) + assert resource_spec_callback(ctx, "a", "cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1;ephemeral_storage=(20Gi,30Gi)") == ( + ResourceSpec(requests=Resources(cpu="0.5", mem="2Gi", gpu="1", ephemeral_storage="20Gi"), limits=Resources(cpu="1", mem="4Gi", gpu=None, ephemeral_storage="30Gi")) + ) + with pytest.raises(click.BadParameter, match="Expected semicolon"): + resource_spec_callback(ctx, "a", "cpu=1,mem=2Gi") + with pytest.raises(click.BadParameter, match="Expected key to be one of"): + resource_spec_callback(ctx, "a", "cpu=1;a=b;mem=2Gi") + with pytest.raises(click.BadParameter, match="Expected unique keys"): + resource_spec_callback(ctx, "a", "cpu=1;mem=2Gi;cpu=1") + + @pytest.mark.parametrize( "param_type", [ From e25d8d41102c2a914a6f81e048f4a0413d18071b Mon Sep 17 00:00:00 2001 From: redartera Date: Sun, 4 May 2025 11:26:47 +0000 Subject: [PATCH 2/6] apply lint Signed-off-by: redartera --- flytekit/core/python_auto_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 2ac03aa12a..9ba8ae6627 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -232,7 +232,7 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain for elem in (settings.env, self.environment): if elem: env.update(elem) - + # Override the task's resource spec if it was not set statically in the task definition def _resources_unspecified(resources: ResourceSpec) -> bool: From 622dcdd56036269934c3788e3d9fef632d0e3ccc Mon Sep 17 00:00:00 2001 From: redartera Date: Mon, 2 Jun 2025 13:40:27 +0000 Subject: [PATCH 3/6] separate requests and limits per review comment Signed-off-by: redartera --- flytekit/clis/sdk_in_container/register.py | 28 +++++++++++------ flytekit/clis/sdk_in_container/run.py | 31 +++++++++++++------ flytekit/interaction/click_types.py | 25 ++++++--------- .../integration/remote/test_remote.py | 12 +++---- .../unit/interaction/test_click_types.py | 25 ++++++--------- 5 files changed, 65 insertions(+), 56 deletions(-) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 422f07b8bd..795b1acd24 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -15,8 +15,8 @@ from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.constants import CopyFileDetection -from flytekit.core.resources import ResourceSpec -from flytekit.interaction.click_types import key_value_callback, resource_spec_callback +from flytekit.core.resources import Resources, ResourceSpec +from flytekit.interaction.click_types import key_value_callback, resource_callback from flytekit.loggers import logger from flytekit.tools import repo @@ -136,13 +136,20 @@ help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) @click.option( - "--default-resources", + "--resource-requests", required=False, type=str, - callback=resource_spec_callback, - help="Override default task resource requests and limits for tasks that have no statically defined resource request and limit. " - """Example usage: --default-resources 'cpu=1;mem=2Gi;gpu=1' for requests only or """ - """--default-resources 'cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1' to specify both requests and limits""", + callback=resource_callback, + help="Override default task resource requests for tasks that have no statically defined resource requests in their task decorator. " + """Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'""", +) +@click.option( + "--resource-limits", + required=False, + type=str, + callback=resource_callback, + help="Override default task resource limits for tasks that have no statically defined resource limits in their task decorator. " + """Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'""", ) @click.option( "--skip-errors", @@ -171,7 +178,8 @@ def register( dry_run: bool, activate_launchplans: bool, env: typing.Optional[typing.Dict[str, str]], - default_resources: typing.Optional[ResourceSpec], + resource_requests: typing.Optional[Resources], + resource_limits: typing.Optional[Resources], skip_errors: bool, ): """ @@ -236,7 +244,9 @@ def register( package_or_module=package_or_module, remote=remote, env=env, - default_resources=default_resources, + default_resources=ResourceSpec( + requests=resource_requests or Resources(), limits=resource_limits or Resources() + ), dry_run=dry_run, activate_launchplans=activate_launchplans, skip_errors=skip_errors, diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 3f69cd4363..b2b8417422 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -43,7 +43,7 @@ from flytekit.core.artifact import ArtifactQuery from flytekit.core.base_task import PythonTask from flytekit.core.data_persistence import FileAccessProvider -from flytekit.core.resources import ResourceSpec +from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException @@ -52,7 +52,7 @@ FlyteLiteralConverter, key_value_callback, labels_callback, - resource_spec_callback, + resource_callback, ) from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger @@ -210,16 +210,26 @@ class RunLevelParams(PyFlyteParams): help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) ) - default_resources: typing.Optional[ResourceSpec] = make_click_option_field( + resource_requests: typing.Optional[Resources] = make_click_option_field( click.Option( - param_decls=["--default-resources"], + param_decls=["--resource-requests"], required=False, show_default=True, type=str, - callback=resource_spec_callback, - help="During fast registration, will override default task resource requests and limits for tasks that have no statically defined resource request and limit. " - """Example usage: --default-resources 'cpu=1;mem=2Gi;gpu=1' for requests only or """ - """--default-resources 'cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1' to specify both requests and limits""", + callback=resource_callback, + help="This overrides default task resource requests for tasks that have no statically defined resource requests in their task decorator. " + """Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'""", + ) + ) + resource_limits: typing.Optional[Resources] = make_click_option_field( + click.Option( + param_decls=["--resource-limits"], + required=False, + show_default=True, + type=str, + callback=resource_callback, + help="This overrides default task resource limits for tasks that have no statically defined resource limits in their task decorator. " + """Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'""", ) ) tags: typing.List[str] = make_click_option_field( @@ -771,7 +781,10 @@ def _run(*args, **kwargs): source_path=run_level_params.computed_params.project_root, module_name=run_level_params.computed_params.module, fast_package_options=fast_package_options, - default_resources=run_level_params.default_resources, + default_resources=ResourceSpec( + requests=run_level_params.resource_requests or Resources(), + limits=run_level_params.resource_limits or Resources(), + ), ) run_remote( diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 046fe71566..9f642a8b5d 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -23,7 +23,7 @@ from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset from flytekit.core.artifact import ArtifactQuery from flytekit.core.data_persistence import FileAccessProvider -from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.resources import Resources from flytekit.core.type_engine import TypeEngine from flytekit.models.types import SimpleType from flytekit.remote.remote_fs import FlytePathResolver @@ -85,38 +85,31 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi return result -def resource_spec_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[ResourceSpec]: +def resource_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[Resources]: """ - Callback for click to parse a resource spec. + Callback for click to parse a resource from a comma-separated string of the form 'cpu=1,mem=2Gi' for example """ if not value: return None - def _extract_pair(s: str) -> typing.Optional[typing.Tuple[str, str]]: - """Can extract the pair of values "0.5" and "1" from the string '(0.5,1)'""" - vals = s.strip("() ").split(",") - if len(vals) != 2: - return None - return vals[0].strip(), vals[1].strip() - - items = value.split(";") + items = value.split(",") _allowed_keys = Resources.__annotations__.keys() result = {} for item in items: kv_split = item.split("=") if len(kv_split) != 2: raise click.BadParameter( - f"Expected semicolon separated key-value pairs of the form 'key1=value1;key2=value2;...', got '{item}'" + f"Expected comma separated key-value pairs of the form 'key1=value1,key2=value2,...', got '{item}'" ) k = kv_split[0].strip() v = kv_split[1].strip() if k not in _allowed_keys: - raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, got '{k}'") + raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, but got '{k}'") if k in result: - raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, got '{k}' multiple times") - result[k.strip()] = _extract_pair(v) or v + raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, but got '{k}' multiple times") + result[k] = v - return ResourceSpec.from_multiple_resource(Resources(**result)) + return Resources(**result) class DirParamType(click.ParamType): diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 3decba7928..918cae8f3c 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -1175,7 +1175,7 @@ def test_register_wf_twice(register): assert out.returncode == 0 -def test_register_wf_with_default_resources_override(register): +def test_register_wf_with_resource_requests_override(register): # Save the version here to retrieve the created task later version = str(uuid.uuid4()) # Register the workflow with overridden default resources @@ -1186,8 +1186,8 @@ def test_register_wf_with_default_resources_override(register): "-c", CONFIG, "register", - "--default-resources", - "cpu=1300m;mem=1100Mi", + "--resource-requests", + "cpu=1300m,mem=1100Mi", "--image", IMAGE, "--project", @@ -1220,7 +1220,7 @@ def test_register_wf_with_default_resources_override(register): ) -def test_run_wf_with_default_resources_override(register): +def test_run_wf_with_resource_requests_override(register): # Save the execution id here to retrieve the created execution later prefix = random.choice(string.ascii_lowercase) short_random_part = uuid.uuid4().hex[:8] @@ -1234,8 +1234,8 @@ def test_run_wf_with_default_resources_override(register): CONFIG, "run", "--remote", - "--default-resources", - "cpu=500m;mem=1Gi", + "--resource-requests", + "cpu=500m,mem=1Gi", "--project", PROJECT, "--domain", diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 92c46946d2..4992e5a91d 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -29,7 +29,7 @@ StructuredDatasetParamType, UnionParamType, key_value_callback, - resource_spec_callback, + resource_callback, ) dummy_param = click.Option(["--dummy"], type=click.STRING, default="dummy") @@ -238,24 +238,17 @@ def test_key_value_callback(): key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"]) -def test_resource_spec_callback(): +def test_resource_callback(): ctx = click.Context(click.Command("test_command"), obj={"remote": True}) - assert resource_spec_callback(ctx, "a", None) is None - assert resource_spec_callback(ctx, "a", "cpu=1;mem=2Gi") == ( - ResourceSpec(requests=Resources(cpu="1", mem="2Gi"), limits=Resources()) - ) - assert resource_spec_callback(ctx, "a", "cpu=1;mem=2Gi;gpu=1") == ( - ResourceSpec(requests=Resources(cpu="1", mem="2Gi", gpu="1"), limits=Resources()) - ) - assert resource_spec_callback(ctx, "a", "cpu=(0.5,1);mem=(2Gi,4Gi);gpu=1;ephemeral_storage=(20Gi,30Gi)") == ( - ResourceSpec(requests=Resources(cpu="0.5", mem="2Gi", gpu="1", ephemeral_storage="20Gi"), limits=Resources(cpu="1", mem="4Gi", gpu=None, ephemeral_storage="30Gi")) - ) - with pytest.raises(click.BadParameter, match="Expected semicolon"): - resource_spec_callback(ctx, "a", "cpu=1,mem=2Gi") + assert resource_callback(ctx, "a", None) is None + assert resource_callback(ctx, "a", "cpu=1,mem=2Gi") == Resources(cpu="1", mem="2Gi") + assert resource_callback(ctx, "a", "cpu=1,mem=2Gi,gpu=1") == Resources(cpu="1", mem="2Gi", gpu="1") + with pytest.raises(click.BadParameter, match="Expected comma separated"): + resource_callback(ctx, "a", "cpu=1;mem=2Gi") with pytest.raises(click.BadParameter, match="Expected key to be one of"): - resource_spec_callback(ctx, "a", "cpu=1;a=b;mem=2Gi") + resource_callback(ctx, "a", "cpu=1,a=b,mem=2Gi") with pytest.raises(click.BadParameter, match="Expected unique keys"): - resource_spec_callback(ctx, "a", "cpu=1;mem=2Gi;cpu=1") + resource_callback(ctx, "a", "cpu=1,mem=2Gi,cpu=1") @pytest.mark.parametrize( From a5ca7722d0b5c72a2f7697f1a5d124b3f0217ed5 Mon Sep 17 00:00:00 2001 From: redartera Date: Tue, 3 Jun 2025 21:20:38 +0000 Subject: [PATCH 4/6] address comment Signed-off-by: redartera --- flytekit/clis/sdk_in_container/run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 99310758d6..eeac521e76 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -218,7 +218,7 @@ class RunLevelParams(PyFlyteParams): type=str, callback=resource_callback, help="This overrides default task resource requests for tasks that have no statically defined resource requests in their task decorator. " - """Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'""", + "Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'", ) ) resource_limits: typing.Optional[Resources] = make_click_option_field( @@ -229,7 +229,7 @@ class RunLevelParams(PyFlyteParams): type=str, callback=resource_callback, help="This overrides default task resource limits for tasks that have no statically defined resource limits in their task decorator. " - """Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'""", + "Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'", ) ) tags: typing.List[str] = make_click_option_field( From 5c3cbe4e43f1578fdf90f4e55d569c29382d0a98 Mon Sep 17 00:00:00 2001 From: redartera Date: Tue, 3 Jun 2025 21:21:50 +0000 Subject: [PATCH 5/6] one more Signed-off-by: redartera --- flytekit/clis/sdk_in_container/register.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 795b1acd24..0bf56fed69 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -141,7 +141,7 @@ type=str, callback=resource_callback, help="Override default task resource requests for tasks that have no statically defined resource requests in their task decorator. " - """Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'""", + "Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'", ) @click.option( "--resource-limits", @@ -149,7 +149,7 @@ type=str, callback=resource_callback, help="Override default task resource limits for tasks that have no statically defined resource limits in their task decorator. " - """Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'""", + "Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'", ) @click.option( "--skip-errors", From 968a04173c9ad78fb459f667397064da114b524a Mon Sep 17 00:00:00 2001 From: redartera Date: Thu, 5 Jun 2025 13:11:12 +0000 Subject: [PATCH 6/6] shorter docstring - put resources in variables for tests Signed-off-by: redartera --- flytekit/interaction/click_types.py | 2 +- .../integration/remote/test_remote.py | 20 +++++++---- .../flytekit/unit/core/test_serialization.py | 36 +++++++++++-------- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 9f642a8b5d..7918339da3 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -87,7 +87,7 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi def resource_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[Resources]: """ - Callback for click to parse a resource from a comma-separated string of the form 'cpu=1,mem=2Gi' for example + Click callback to parse resource strings like 'cpu=1,mem=2Gi' into a Resources object """ if not value: return None diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 16004cbc48..18481d9e69 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -1260,6 +1260,10 @@ def test_register_wf_twice(register): def test_register_wf_with_resource_requests_override(register): # Save the version here to retrieve the created task later version = str(uuid.uuid4()) + + cpu = "1300m" + mem = "1100Mi" + # Register the workflow with overridden default resources out = subprocess.run( [ @@ -1269,7 +1273,7 @@ def test_register_wf_with_resource_requests_override(register): CONFIG, "register", "--resource-requests", - "cpu=1300m,mem=1100Mi", + f"cpu={cpu},mem={mem}", "--image", IMAGE, "--project", @@ -1291,11 +1295,11 @@ def test_register_wf_with_resource_requests_override(register): requests=[ task_models.Resources.ResourceEntry( name=task_models.Resources.ResourceName.CPU, - value="1300m", + value=cpu, ), task_models.Resources.ResourceEntry( name=task_models.Resources.ResourceName.MEMORY, - value="1100Mi", + value=mem, ), ], limits=[], @@ -1307,6 +1311,10 @@ def test_run_wf_with_resource_requests_override(register): prefix = random.choice(string.ascii_lowercase) short_random_part = uuid.uuid4().hex[:8] execution_id = f"{prefix}{short_random_part}" + + cpu = "500m" + mem = "1Gi" + # Register the workflow with overridden default resources out = subprocess.run( [ @@ -1317,7 +1325,7 @@ def test_run_wf_with_resource_requests_override(register): "run", "--remote", "--resource-requests", - "cpu=500m,mem=1Gi", + f"cpu={cpu},mem={mem}", "--project", PROJECT, "--domain", @@ -1341,11 +1349,11 @@ def test_run_wf_with_resource_requests_override(register): requests=[ task_models.Resources.ResourceEntry( name=task_models.Resources.ResourceName.CPU, - value="500m", + value=cpu, ), task_models.Resources.ResourceEntry( name=task_models.Resources.ResourceName.MEMORY, - value="1Gi", + value=mem, ), ], limits=[], diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index f112685080..429305e931 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -1175,19 +1175,22 @@ def t1(a: int) -> int: def test_default_resources_override_resourceless_tasks(): """Tests that default resources specified in serialization settings affect tasks where no resources are specified.""" + cpu = "2" + mem = "4Gi" + _settings = dataclasses.replace( serialization_settings, - default_resources=ResourceSpec(requests=Resources(cpu="2", mem="4Gi"), limits=Resources(cpu="2", mem="4Gi")) + default_resources=ResourceSpec(requests=Resources(cpu=cpu, mem=mem), limits=Resources(cpu=cpu, mem=mem)) ) expected_default_resources = task_models.Resources( requests=[ - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="2"), - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="4Gi"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem), ], limits=[ - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="2"), - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="4Gi"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem), ], ) @@ -1226,20 +1229,25 @@ def test_default_resources_do_not_overriden_tasks_with_explicit_resources(): default_resources=ResourceSpec(requests=Resources(cpu="2", mem="4Gi"), limits=Resources(cpu="2", mem="4Gi")) ) - @task(requests=Resources(cpu="1", mem="2Gi")) + # These cpu/mem values are static i.e. used in task decorators - should not be overridden by + # resources from serialization settings + cpu_static = "1" + mem_static = "2Gi" + + @task(requests=Resources(cpu=cpu_static, mem=mem_static)) def t1(a: int) -> int: return a t1_spec = get_serializable(OrderedDict(), _settings, t1) assert t1_spec.template.container.resources == task_models.Resources( requests=[ - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="1"), - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="2Gi"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu_static), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem_static), ], limits=[] ) - @dynamic(limits=Resources(cpu="1", mem="2Gi")) + @dynamic(limits=Resources(cpu=cpu_static, mem=mem_static)) def t1_dynamic(a: int) -> int: return a @@ -1247,20 +1255,20 @@ def t1_dynamic(a: int) -> int: assert t1_dynamic_spec.template.container.resources == task_models.Resources( requests=[], limits=[ - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="1"), - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="2Gi"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu_static), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem_static), ] ) - @eager(requests=Resources(cpu="1", mem="4Gi")) + @eager(requests=Resources(cpu=cpu_static, mem=mem_static)) async def t1_eager(a: int) -> int: return a t1_eager_spec = get_serializable(OrderedDict(), _settings, t1_eager) assert t1_eager_spec.template.container.resources == task_models.Resources( requests=[ - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value="1"), - task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value="4Gi"), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu_static), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem_static), ], limits=[] )