From 6d7148eef4c36f38f79b30cf2914c2ea1a8c10c6 Mon Sep 17 00:00:00 2001 From: rambrus Date: Tue, 15 Apr 2025 10:51:26 +0200 Subject: [PATCH 1/4] Bugfix for issue #6405: Fast registration runs into infinite loop Signed-off-by: rambrus --- plugins/flytekit-spark/flytekitplugins/spark/task.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 00b9f59bc5..b9a0265f41 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -1,6 +1,7 @@ import dataclasses import os import shutil +import tempfile from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union, cast @@ -237,10 +238,11 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: and ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION ): + base_dir = tempfile.TemporaryDirectory().name file_name = "flyte_wf" file_format = "zip" - shutil.make_archive(file_name, file_format, os.getcwd()) - self.sess.sparkContext.addPyFile(f"{file_name}.{file_format}") + shutil.make_archive(f"{base_dir}/{file_name}", file_format, os.getcwd()) + self.sess.sparkContext.addPyFile(f"{base_dir}/{file_name}.{file_format}") return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() From 1172159f478fed3b16ee1ea0bff5ea7c4b6c30bb Mon Sep 17 00:00:00 2001 From: rambrus Date: Tue, 15 Apr 2025 11:21:45 +0200 Subject: [PATCH 2/4] Replace tempfile.TemporaryDirectory().name with tempfile.mkdtemp() Signed-off-by: rambrus --- plugins/flytekit-spark/flytekitplugins/spark/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index b9a0265f41..5801d24fde 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -238,7 +238,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: and ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION ): - base_dir = tempfile.TemporaryDirectory().name + base_dir = tempfile.mkdtemp() file_name = "flyte_wf" file_format = "zip" shutil.make_archive(f"{base_dir}/{file_name}", file_format, os.getcwd()) From 854d54aec7c4136b9eb02ce54ae3a7379880d056 Mon Sep 17 00:00:00 2001 From: rambrus Date: Tue, 15 Apr 2025 16:46:27 +0200 Subject: [PATCH 3/4] Update and add unit tests Signed-off-by: rambrus --- .../flytekit-spark/tests/test_spark_task.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index f510af5f24..dd30f6f3b8 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -4,6 +4,7 @@ import pandas as pd import pyspark import pytest +import tempfile from google.protobuf.json_format import MessageToDict from flytekit import PodTemplate @@ -191,8 +192,46 @@ def my_spark(a: int) -> int: ) as new_ctx: my_spark.pre_execute(new_ctx.user_space_params) mock_add_pyfile.assert_called_once() - os.remove(os.path.join(os.getcwd(), "flyte_wf.zip")) +@mock.patch("tempfile.mkdtemp", return_value="/tmp/123") +@mock.patch("shutil.make_archive") +@mock.patch("pyspark.context.SparkContext.addPyFile") +def test_spark_archive_created_in_temp_dir(mock_add_pyfile, mock_shutil_make_archive, mock_tempfile_mkdtemp): + @task( + task_config=Spark( + spark_conf={"spark": "1"}, + ) + ) + def my_spark(a: int) -> int: + return a + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/User/flyte/workflows", + distribution_location="s3://my-s3-bucket/fast/123", + ), + ) + + ctx = context_manager.FlyteContextManager.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.new_execution_state().with_params( + mode=ExecutionState.Mode.TASK_EXECUTION + ) + ).with_serialization_settings(serialization_settings) + ) as new_ctx: + my_spark.pre_execute(new_ctx.user_space_params) + + mock_tempfile_mkdtemp.assert_called_once() + mock_shutil_make_archive.assert_called_once_with("/tmp/123/flyte_wf", "zip", os.getcwd()) + mock_add_pyfile.assert_called_once_with("/tmp/123/flyte_wf.zip") def test_spark_with_image_spec(): custom_image = ImageSpec( From cb153b491f94388140b49f876a166380f1d93453 Mon Sep 17 00:00:00 2001 From: rambrus Date: Wed, 16 Apr 2025 07:46:28 +0200 Subject: [PATCH 4/4] Merge unit tests, fix lint issues Signed-off-by: rambrus --- .../flytekit-spark/tests/test_spark_task.py | 39 +------------------ 1 file changed, 2 insertions(+), 37 deletions(-) diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index dd30f6f3b8..ff3e1797b7 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -158,45 +158,10 @@ def test_to_html(): assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output -@mock.patch("pyspark.context.SparkContext.addPyFile") -def test_spark_addPyFile(mock_add_pyfile): - @task( - task_config=Spark( - spark_conf={"spark": "1"}, - ) - ) - def my_spark(a: int) -> int: - return a - - default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - fast_serialization_settings=FastSerializationSettings( - enabled=True, - destination_dir="/User/flyte/workflows", - distribution_location="s3://my-s3-bucket/fast/123", - ), - ) - - ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.new_execution_state().with_params( - mode=ExecutionState.Mode.TASK_EXECUTION - ) - ).with_serialization_settings(serialization_settings) - ) as new_ctx: - my_spark.pre_execute(new_ctx.user_space_params) - mock_add_pyfile.assert_called_once() - @mock.patch("tempfile.mkdtemp", return_value="/tmp/123") @mock.patch("shutil.make_archive") @mock.patch("pyspark.context.SparkContext.addPyFile") -def test_spark_archive_created_in_temp_dir(mock_add_pyfile, mock_shutil_make_archive, mock_tempfile_mkdtemp): +def test_spark_addPyFile(mock_add_pyfile, mock_shutil_make_archive, mock_tempfile_mkdtemp): @task( task_config=Spark( spark_conf={"spark": "1"}, @@ -230,7 +195,7 @@ def my_spark(a: int) -> int: my_spark.pre_execute(new_ctx.user_space_params) mock_tempfile_mkdtemp.assert_called_once() - mock_shutil_make_archive.assert_called_once_with("/tmp/123/flyte_wf", "zip", os.getcwd()) + mock_shutil_make_archive.assert_called_once_with("/tmp/123/flyte_wf", "zip", os.getcwd()) mock_add_pyfile.assert_called_once_with("/tmp/123/flyte_wf.zip") def test_spark_with_image_spec():