diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 00b9f59bc5..5801d24fde 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -1,6 +1,7 @@ import dataclasses import os import shutil +import tempfile from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Union, cast @@ -237,10 +238,11 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: and ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION ): + base_dir = tempfile.mkdtemp() file_name = "flyte_wf" file_format = "zip" - shutil.make_archive(file_name, file_format, os.getcwd()) - self.sess.sparkContext.addPyFile(f"{file_name}.{file_format}") + shutil.make_archive(f"{base_dir}/{file_name}", file_format, os.getcwd()) + self.sess.sparkContext.addPyFile(f"{base_dir}/{file_name}.{file_format}") return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index f510af5f24..ff3e1797b7 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -4,6 +4,7 @@ import pandas as pd import pyspark import pytest +import tempfile from google.protobuf.json_format import MessageToDict from flytekit import PodTemplate @@ -157,8 +158,10 @@ def test_to_html(): assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output +@mock.patch("tempfile.mkdtemp", return_value="/tmp/123") +@mock.patch("shutil.make_archive") @mock.patch("pyspark.context.SparkContext.addPyFile") -def test_spark_addPyFile(mock_add_pyfile): +def test_spark_addPyFile(mock_add_pyfile, mock_shutil_make_archive, mock_tempfile_mkdtemp): @task( task_config=Spark( spark_conf={"spark": "1"}, @@ -190,9 +193,10 @@ def my_spark(a: int) -> int: ).with_serialization_settings(serialization_settings) ) as new_ctx: my_spark.pre_execute(new_ctx.user_space_params) - mock_add_pyfile.assert_called_once() - os.remove(os.path.join(os.getcwd(), "flyte_wf.zip")) + mock_tempfile_mkdtemp.assert_called_once() + mock_shutil_make_archive.assert_called_once_with("/tmp/123/flyte_wf", "zip", os.getcwd()) + mock_add_pyfile.assert_called_once_with("/tmp/123/flyte_wf.zip") def test_spark_with_image_spec(): custom_image = ImageSpec(