diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 57a9b58448..0bf56fed69 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -15,7 +15,8 @@ from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.constants import CopyFileDetection -from flytekit.interaction.click_types import key_value_callback +from flytekit.core.resources import Resources, ResourceSpec +from flytekit.interaction.click_types import key_value_callback, resource_callback from flytekit.loggers import logger from flytekit.tools import repo @@ -134,6 +135,22 @@ callback=key_value_callback, help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) +@click.option( + "--resource-requests", + required=False, + type=str, + callback=resource_callback, + help="Override default task resource requests for tasks that have no statically defined resource requests in their task decorator. " + "Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'", +) +@click.option( + "--resource-limits", + required=False, + type=str, + callback=resource_callback, + help="Override default task resource limits for tasks that have no statically defined resource limits in their task decorator. " + "Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'", +) @click.option( "--skip-errors", "--skip-error", @@ -161,6 +178,8 @@ def register( dry_run: bool, activate_launchplans: bool, env: typing.Optional[typing.Dict[str, str]], + resource_requests: typing.Optional[Resources], + resource_limits: typing.Optional[Resources], skip_errors: bool, ): """ @@ -225,6 +244,9 @@ def register( package_or_module=package_or_module, remote=remote, env=env, + default_resources=ResourceSpec( + requests=resource_requests or Resources(), limits=resource_limits or Resources() + ), dry_run=dry_run, activate_launchplans=activate_launchplans, skip_errors=skip_errors, diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 92687879fc..eeac521e76 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -43,6 +43,7 @@ from flytekit.core.artifact import ArtifactQuery from flytekit.core.base_task import PythonTask from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException @@ -51,6 +52,7 @@ FlyteLiteralConverter, key_value_callback, labels_callback, + resource_callback, ) from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger @@ -208,6 +210,28 @@ class RunLevelParams(PyFlyteParams): help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) ) + resource_requests: typing.Optional[Resources] = make_click_option_field( + click.Option( + param_decls=["--resource-requests"], + required=False, + show_default=True, + type=str, + callback=resource_callback, + help="This overrides default task resource requests for tasks that have no statically defined resource requests in their task decorator. " + "Example usage: --resource-requests 'cpu=1,mem=2Gi,gpu=1'", + ) + ) + resource_limits: typing.Optional[Resources] = make_click_option_field( + click.Option( + param_decls=["--resource-limits"], + required=False, + show_default=True, + type=str, + callback=resource_callback, + help="This overrides default task resource limits for tasks that have no statically defined resource limits in their task decorator. " + "Example usage: --resource-limits 'cpu=1,mem=2Gi,gpu=1'", + ) + ) tags: typing.List[str] = make_click_option_field( click.Option( param_decls=["--tags", "--tag"], @@ -756,6 +780,10 @@ def _run(*args, **kwargs): source_path=run_level_params.computed_params.project_root, module_name=run_level_params.computed_params.module, fast_package_options=fast_package_options, + default_resources=ResourceSpec( + requests=run_level_params.resource_requests or Resources(), + limits=run_level_params.resource_limits or Resources(), + ), ) run_remote( diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index d27e07c886..ec21647d06 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -129,6 +129,7 @@ from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.core.resources import ResourceSpec from flytekit.image_spec import ImageSpec from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.loggers import logger @@ -805,6 +806,8 @@ class SerializationSettings(DataClassJsonMixin): version (str): The version (if any) with which to register entities under. image_config (ImageConfig): The image config used to define task container images. env (Optional[Dict[str, str]]): Environment variables injected into task container definitions. + default_resources (Optional[ResourceSpec]): The resources to request for the task - this is useful + if users need to override the default resource spec of an entity at registration time. flytekit_virtualenv_root (Optional[str]): During out of container serialize the absolute path of the flytekit virtualenv at serialization time won't match the in-container value at execution time. This optional value is used to provide the in-container virtualenv path @@ -823,6 +826,7 @@ class SerializationSettings(DataClassJsonMixin): domain: typing.Optional[str] = None version: typing.Optional[str] = None env: Optional[Dict[str, str]] = None + default_resources: Optional[ResourceSpec] = None git_repo: Optional[str] = None python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None @@ -897,6 +901,7 @@ def new_builder(self) -> Builder: version=self.version, image_config=self.image_config, env=self.env.copy() if self.env else None, + default_resources=self.default_resources, git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, @@ -948,6 +953,7 @@ class Builder(object): version: str image_config: ImageConfig env: Optional[Dict[str, str]] = None + default_resources: Optional[ResourceSpec] = None git_repo: Optional[str] = None flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None @@ -965,6 +971,7 @@ def build(self) -> SerializationSettings: version=self.version, image_config=self.image_config, env=self.env, + default_resources=self.default_resources, git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index a884b150c0..9ba8ae6627 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -233,6 +233,17 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain if elem: env.update(elem) + # Override the task's resource spec if it was not set statically in the task definition + + def _resources_unspecified(resources: ResourceSpec) -> bool: + return resources == ResourceSpec( + requests=Resources(), + limits=Resources(), + ) + + if isinstance(settings.default_resources, ResourceSpec) and _resources_unspecified(self.resources): + self._resources = settings.default_resources + # Add runtime dependencies into environment if isinstance(self.container_image, ImageSpec) and self.container_image.runtime_packages: runtime_packages = " ".join(self.container_image.runtime_packages) diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 89bc58e117..7918339da3 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -23,6 +23,7 @@ from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset from flytekit.core.artifact import ArtifactQuery from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.resources import Resources from flytekit.core.type_engine import TypeEngine from flytekit.models.types import SimpleType from flytekit.remote.remote_fs import FlytePathResolver @@ -84,6 +85,33 @@ def labels_callback(_: typing.Any, param: str, values: typing.List[str]) -> typi return result +def resource_callback(_: typing.Any, param: str, value: typing.Optional[str]) -> typing.Optional[Resources]: + """ + Click callback to parse resource strings like 'cpu=1,mem=2Gi' into a Resources object + """ + if not value: + return None + + items = value.split(",") + _allowed_keys = Resources.__annotations__.keys() + result = {} + for item in items: + kv_split = item.split("=") + if len(kv_split) != 2: + raise click.BadParameter( + f"Expected comma separated key-value pairs of the form 'key1=value1,key2=value2,...', got '{item}'" + ) + k = kv_split[0].strip() + v = kv_split[1].strip() + if k not in _allowed_keys: + raise click.BadParameter(f"Expected key to be one of {list(_allowed_keys)}, but got '{k}'") + if k in result: + raise click.BadParameter(f"Expected unique keys {list(_allowed_keys)}, but got '{k}' multiple times") + result[k] = v + + return Resources(**result) + + class DirParamType(click.ParamType): name = "directory path" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 98e6381198..b2ba9bcd67 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -56,6 +56,7 @@ ) from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec +from flytekit.core.resources import ResourceSpec from flytekit.core.task import ReferenceTask from flytekit.core.tracker import extract_task_module from flytekit.core.type_engine import LiteralsResolver, TypeEngine, strict_type_hint_matching @@ -1326,6 +1327,7 @@ def register_script( source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + default_resources: typing.Optional[ResourceSpec] = None, fast_package_options: typing.Optional[FastPackageOptions] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan, ReferenceEntity]: """ @@ -1342,6 +1344,7 @@ def register_script( :param source_path: The root of the project path :param module_name: the name of the module :param envs: Environment variables to be passed to the serialization + :param default_resources: Default resources to be passed to the serialization. These override the resource spec for any tasks that have no statically defined resource requests and limits. :param fast_package_options: Options to customize copy_all behavior, ignored when copy_all is False. :return: """ @@ -1380,6 +1383,7 @@ def register_script( image_config=image_config, git_repo=_get_git_repo_url(source_path), env=envs, + default_resources=default_resources, fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index b2c5cd6632..975afdb445 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -13,6 +13,7 @@ from flytekit.constants import CopyFileDetection from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager, FlyteEntities +from flytekit.core.resources import ResourceSpec from flytekit.loggers import logger from flytekit.models import launch_plan, task from flytekit.models.core.identifier import Identifier @@ -252,6 +253,7 @@ def register( remote: FlyteRemote, copy_style: CopyFileDetection, env: typing.Optional[typing.Dict[str, str]], + default_resources: typing.Optional[ResourceSpec], dry_run: bool = False, activate_launchplans: bool = False, skip_errors: bool = False, @@ -274,6 +276,7 @@ def register( image_config=image_config, fast_serialization_settings=None, # should probably add incomplete fast settings env=env, + default_resources=default_resources, ) if not version and copy_style == CopyFileDetection.NO_COPY: diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 276088534f..18481d9e69 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -16,6 +16,8 @@ import uuid import pytest from unittest import mock +import random +import string from dataclasses import asdict, dataclass from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow @@ -23,6 +25,7 @@ from flytekit.core.launch_plan import reference_launch_plan from flytekit.core.task import reference_task from flytekit.core.workflow import reference_workflow +from flytekit.models import task as task_models from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.remote.remote import FlyteRemote @@ -1252,3 +1255,106 @@ def test_register_wf_twice(register): ] ) assert out.returncode == 0 + + +def test_register_wf_with_resource_requests_override(register): + # Save the version here to retrieve the created task later + version = str(uuid.uuid4()) + + cpu = "1300m" + mem = "1100Mi" + + # Register the workflow with overridden default resources + out = subprocess.run( + [ + "pyflyte", + "--verbose", + "-c", + CONFIG, + "register", + "--resource-requests", + f"cpu={cpu},mem={mem}", + "--image", + IMAGE, + "--project", + PROJECT, + "--domain", + DOMAIN, + "--version", + version, + MODULE_PATH / "hello_world.py", + ] + ) + assert out.returncode == 0 + + # Retrieve the created task + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) + assert task.template.container is not None + assert task.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.CPU, + value=cpu, + ), + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.MEMORY, + value=mem, + ), + ], + limits=[], + ) + + +def test_run_wf_with_resource_requests_override(register): + # Save the execution id here to retrieve the created execution later + prefix = random.choice(string.ascii_lowercase) + short_random_part = uuid.uuid4().hex[:8] + execution_id = f"{prefix}{short_random_part}" + + cpu = "500m" + mem = "1Gi" + + # Register the workflow with overridden default resources + out = subprocess.run( + [ + "pyflyte", + "--verbose", + "-c", + CONFIG, + "run", + "--remote", + "--resource-requests", + f"cpu={cpu},mem={mem}", + "--project", + PROJECT, + "--domain", + DOMAIN, + "--name", + execution_id, + MODULE_PATH / "hello_world.py", + "my_wf" + ] + ) + assert out.returncode == 0 + + # Retrieve the created task + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution) + version = execution.spec.launch_plan.version + task = remote.fetch_task(name="basic.hello_world.say_hello", version=version) + assert task.template.container is not None + assert task.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.CPU, + value=cpu, + ), + task_models.Resources.ResourceEntry( + name=task_models.Resources.ResourceName.MEMORY, + value=mem, + ), + ], + limits=[], + ) diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index b2702c0a75..8fde709979 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -19,6 +19,7 @@ ) from flytekit.core import mock_stats, context_manager from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, SecretsManager +from flytekit.core.resources import ResourceSpec, Resources from flytekit.models.core import identifier as id_models @@ -301,6 +302,7 @@ def test_serialization_settings_transport(): domain="domain", version="version", env={"hello": "blah"}, + default_resources=ResourceSpec(requests=Resources(cpu="1", mem="2Gi"), limits=Resources(cpu="1", mem="2Gi")), image_config=ImageConfig( default_image=default_img, images=[default_img], @@ -322,7 +324,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 408 + assert len(tp) == 480 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 833a6d74ab..429305e931 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -1,6 +1,7 @@ import re import os import typing +import dataclasses from collections import OrderedDict import mock @@ -11,7 +12,10 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.condition import conditional from flytekit.core.python_auto_container import get_registerable_container_image -from flytekit.core.task import task +from flytekit.core.resources import Resources, ResourceSpec +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.array_node_map_task import map_task +from flytekit.core.task import eager, task from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion, FlyteMissingTypeException from flytekit.image_spec.image_spec import ImageBuildEngine @@ -28,6 +32,7 @@ Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType +from flytekit.models import task as task_models from flytekit.tools.translator import get_serializable from flytekit.types.error.error import FlyteError @@ -1165,3 +1170,105 @@ def t1(a: int) -> int: with pytest.raises(AssertionError, match="Got multiple values for argument"): t1(1, a=2) + + +def test_default_resources_override_resourceless_tasks(): + """Tests that default resources specified in serialization settings affect tasks where no resources are specified.""" + + cpu = "2" + mem = "4Gi" + + _settings = dataclasses.replace( + serialization_settings, + default_resources=ResourceSpec(requests=Resources(cpu=cpu, mem=mem), limits=Resources(cpu=cpu, mem=mem)) + ) + + expected_default_resources = task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem), + ], + limits=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem), + ], + ) + + + # We test against the 4 fundamental constructs of tasks: @task, @dynamic, @eager, and map_task + + @task + def t1(a: int) -> int: + return a + t1_spec = get_serializable(OrderedDict(), _settings, t1) + assert t1_spec.template.container.resources == expected_default_resources + + @dynamic + def t1_dynamic(a: int) -> int: + return a + t1_dynamic_spec = get_serializable(OrderedDict(), _settings, t1_dynamic) + assert t1_dynamic_spec.template.container.resources == expected_default_resources + + @eager + async def t1_eager(a: int) -> int: + return a + + t1_eager_spec = get_serializable(OrderedDict(), _settings, t1_eager) + assert t1_eager_spec.template.container.resources == expected_default_resources + + t1_map_task = map_task(t1) + t1_map_task_spec = get_serializable(OrderedDict(), _settings, t1_map_task) + assert t1_map_task_spec.template.container.resources == expected_default_resources + + +def test_default_resources_do_not_overriden_tasks_with_explicit_resources(): + """Tests that default resources specified in serialization settings do not override resources specified in task decorators.""" + + _settings = dataclasses.replace( + serialization_settings, + default_resources=ResourceSpec(requests=Resources(cpu="2", mem="4Gi"), limits=Resources(cpu="2", mem="4Gi")) + ) + + # These cpu/mem values are static i.e. used in task decorators - should not be overridden by + # resources from serialization settings + cpu_static = "1" + mem_static = "2Gi" + + @task(requests=Resources(cpu=cpu_static, mem=mem_static)) + def t1(a: int) -> int: + return a + + t1_spec = get_serializable(OrderedDict(), _settings, t1) + assert t1_spec.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu_static), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem_static), + ], + limits=[] + ) + + @dynamic(limits=Resources(cpu=cpu_static, mem=mem_static)) + def t1_dynamic(a: int) -> int: + return a + + t1_dynamic_spec = get_serializable(OrderedDict(), _settings, t1_dynamic) + assert t1_dynamic_spec.template.container.resources == task_models.Resources( + requests=[], + limits=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu_static), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem_static), + ] + ) + + @eager(requests=Resources(cpu=cpu_static, mem=mem_static)) + async def t1_eager(a: int) -> int: + return a + + t1_eager_spec = get_serializable(OrderedDict(), _settings, t1_eager) + assert t1_eager_spec.template.container.resources == task_models.Resources( + requests=[ + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.CPU, value=cpu_static), + task_models.Resources.ResourceEntry(name=task_models.Resources.ResourceName.MEMORY, value=mem_static), + ], + limits=[] + ) diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py index 33e0bbf78f..4992e5a91d 100644 --- a/tests/flytekit/unit/interaction/test_click_types.py +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -15,6 +15,7 @@ from flytekit import FlyteContextManager from flytekit.core.artifact import Artifact +from flytekit.core.resources import ResourceSpec, Resources from flytekit.core.type_engine import TypeEngine from flytekit.interaction.click_types import ( DateTimeType, @@ -28,6 +29,7 @@ StructuredDatasetParamType, UnionParamType, key_value_callback, + resource_callback, ) dummy_param = click.Option(["--dummy"], type=click.STRING, default="dummy") @@ -236,6 +238,19 @@ def test_key_value_callback(): key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"]) +def test_resource_callback(): + ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + assert resource_callback(ctx, "a", None) is None + assert resource_callback(ctx, "a", "cpu=1,mem=2Gi") == Resources(cpu="1", mem="2Gi") + assert resource_callback(ctx, "a", "cpu=1,mem=2Gi,gpu=1") == Resources(cpu="1", mem="2Gi", gpu="1") + with pytest.raises(click.BadParameter, match="Expected comma separated"): + resource_callback(ctx, "a", "cpu=1;mem=2Gi") + with pytest.raises(click.BadParameter, match="Expected key to be one of"): + resource_callback(ctx, "a", "cpu=1,a=b,mem=2Gi") + with pytest.raises(click.BadParameter, match="Expected unique keys"): + resource_callback(ctx, "a", "cpu=1,mem=2Gi,cpu=1") + + @pytest.mark.parametrize( "param_type", [