From c99435b5cbd186d6c32dfaa5240caf7edad8e991 Mon Sep 17 00:00:00 2001 From: redartera Date: Tue, 15 Apr 2025 13:15:53 +0000 Subject: [PATCH] enable pyflyte run --interruptible Signed-off-by: redartera --- flytekit/clis/sdk_in_container/run.py | 12 ++++++ tests/flytekit/unit/cli/pyflyte/test_run.py | 41 +++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 7a08ef31af..c74680e67c 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -186,6 +186,17 @@ class RunLevelParams(PyFlyteParams): help="Whether to overwrite the cache if it already exists", ) ) + interruptible: typing.Optional[bool] = make_click_option_field( + click.Option( + param_decls=["--interruptible"], + type=bool, + required=False, + default=None, + help="Specify if the execution should be forced to run with an interruptible flag of true or false." + " Use '--interruptible true' or '--interruptible false' to explicitly enable/disable." + " If this option is not provided, the default interruptible behavior of the remote Flyte entity is used.", + ) + ) envvars: typing.Dict[str, str] = make_click_option_field( click.Option( param_decls=["--envvars", "--env"], @@ -557,6 +568,7 @@ def run_remote( options=options_from_run_params(run_level_params), type_hints=type_hints, overwrite_cache=run_level_params.overwrite_cache, + interruptible=run_level_params.interruptible, envs=run_level_params.envvars, tags=run_level_params.tags, cluster_pool=run_level_params.cluster_pool, diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index e4ab4145dc..6bcea871e7 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -96,6 +96,47 @@ def test_pyflyte_run_wf(remote, remote_flag, workflow_file): assert result.exit_code == 0 +@pytest.mark.parametrize("workflow_file", [WorkflowFileLocation.NORMAL], indirect=["workflow_file"]) +@pytest.mark.parametrize( + "interruptible_cli_val,expected_interruptible_override", + [ + ("true", True), + ("True", True), + ("TRUE", True), + ("false", False), + ("False", False), + ("FALSE", False), + (None, None), + ], +) +def test_pyflyte_run_wf_interruptible(workflow_file, interruptible_cli_val, expected_interruptible_override): + """Tests that the '--interruptible' option passed to 'pyflyte run' is correctly passed to the remote execution.""" + # Build the pyflyte args + pyflyte_args = [ + "run", + "--remote", + str(workflow_file), + "wf_with_list", + "--a", "[1,2,3]", + ] + # Insert "--interruptible" if interruptible_cli_val is not None + if interruptible_cli_val is not None: + pyflyte_args = pyflyte_args[:2] + ["--interruptible", interruptible_cli_val] + pyflyte_args[2:] + # Run the command - but mock 'FlyteRemote' to check what value 'remote.execute' was called with + with mock.patch("flytekit.configuration.plugin.FlyteRemote") as mocked_remote: + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + pyflyte_args, + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + # Check that 'remote.execute' was called once with the correct interruptible override + assert mocked_remote.return_value.execute.call_count == 1 + assert mocked_remote.return_value.execute.call_args[1]["interruptible"] == expected_interruptible_override + + + def test_pyflyte_run_with_labels(): workflow_file = pathlib.Path(__file__).parent / "workflow.py" with mock.patch("flytekit.configuration.plugin.FlyteRemote"):