From 950676cd320d76386130246afd4ad25d43a603c0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 11 Mar 2025 17:34:39 +0530 Subject: [PATCH 1/7] mosaic data streaming integration Signed-off-by: Samhita Alla --- flytekit/types/directory/__init__.py | 2 +- flytekit/types/directory/types.py | 183 ++++++++++++++++-- .../unit/types/directory/data_streamer.py | 92 +++++++++ 3 files changed, 259 insertions(+), 18 deletions(-) create mode 100644 tests/flytekit/unit/types/directory/data_streamer.py diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index 83bb0c8fa8..419ee0a51b 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -16,7 +16,7 @@ import typing -from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer +from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer, StreamingKwargs # The following section provides some predefined aliases for commonly used FlyteDirectory formats. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 699278b0b6..0ea17aace8 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -8,10 +8,11 @@ from dataclasses import dataclass, field from functools import partial from pathlib import Path -from typing import Any, Dict, Generator, Tuple +from typing import Annotated, Any, Dict, Generator, Tuple from uuid import UUID import fsspec +import jsonlines import msgpack from dataclasses_json import DataClassJsonMixin, config from fsspec.utils import get_protocol @@ -19,12 +20,22 @@ from google.protobuf.struct_pb2 import Struct from marshmallow import fields from mashumaro.types import SerializableType +from typing_extensions import get_args, get_origin from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size +from flytekit.core.type_engine import ( + AsyncTypeTransformer, + TypeEngine, + TypeTransformerFailedError, + get_batch_size, +) from flytekit.exceptions.user import FlyteAssertion -from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator +from flytekit.extras.pydantic_transformer.decorator import ( + model_serializer, + model_validator, +) +from flytekit.loggers import logger from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.core.types import BlobType @@ -32,6 +43,18 @@ from flytekit.models.types import LiteralType from flytekit.types.file import FileExt, FlyteFile +try: + import streaming # noqa: F401 + + _has_streaming = True +except ImportError: + _has_streaming = False + +if _has_streaming: + from streaming.base import MDSWriter, StreamingDataset +else: + logger.info("Streaming is unavailable.") + T = typing.TypeVar("T") PathType = typing.Union[str, os.PathLike] @@ -39,6 +62,17 @@ def noop(): ... +@dataclass +class StreamingKwargs(DataClassJsonMixin): + shards_config: typing.Dict[str, Any] = field(default=None, metadata=config(mm_field=fields.Dict())) + stream_config: typing.Dict[str, Any] = field(default=None, metadata=config(mm_field=fields.Dict())) + + def __post_init__(self): + columns = self.shards_config.get("columns") + if columns: + self.shards_config["columns"] = {k: v.__name__ if isinstance(v, type) else v for k, v in columns.items()} + + @dataclass class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]): path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore @@ -47,7 +81,8 @@ class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.G This class should not be used on very large datasets, as merely listing the dataset will cause the entire dataset to be downloaded. Listing on S3 and other backend object stores is not consistent - and we should not need data to be downloaded to list. + and we should not need data to be downloaded to list. If you need to work with large datasets efficiently, + consider using FlyteDirectory with **streaming** instead of downloading everything at once. Please first read through the comments on the :py:class:`flytekit.types.file.FlyteFile` class as the implementation here is similar. @@ -126,6 +161,21 @@ def t1(in1: FlyteDirectory["svg"]): The format [] bit is still there because in Flyte, directories are stored as Blob Types also, just like files, and the Blob type has the format field. The difference in the type field is represented in the ``dimensionality`` field in the ``BlobType``. + + To stream a FlyteDirectory, use the following approach: + + .. code-block:: python + + from typing import Annotated + from flytekit.types.directory.types import FlyteDirectory + + def t2(dataset: Annotated[FlyteDirectory, StreamingKwargs(shards_config={}, stream_config={})]): + # Returns an instance of a subclass of PyTorch's IterableDataset, yielding data samples as an iterator. + for i in range(dataset.num_samples): + print(dataset[i]) + + This leverages MosaicML's streaming library under the hood. + The dataset is represented as a StreamingDataset, which extends PyTorch's IterableDataset, enabling efficient, on-the-fly data loading. """ def _serialize(self) -> typing.Dict[str, str]: @@ -632,20 +682,102 @@ def wf(dc: DC): python_val = json.loads(json_str) return self.dict_to_flyte_directory(python_val, expected_python_type) + def _is_valid_jsonl_file(self, file_path: str) -> bool: + if not os.path.isfile(file_path) or not file_path.endswith(".jsonl"): + return False + try: + with jsonlines.open(file_path) as reader: + for _ in reader: + pass + return True + except (jsonlines.InvalidLineError, UnicodeDecodeError): + return False + + def _create_shards( + self, ctx: FlyteContext, uri: str, fd: FlyteDirectory = None, aa: StreamingKwargs = None + ) -> typing.Union[bool, str]: + if not aa.shards_config: + return None + + # Set output path if not provided + if "out" not in aa.shards_config: + aa.shards_config["out"] = ctx.file_access.get_random_local_directory() + + # Create shards + with MDSWriter(**aa.shards_config) as out: + if fd: # Remote directory case + for base, x in fd.crawl(): + src = str(os.path.join(base, x)) + if self._is_valid_jsonl_file(src): + with jsonlines.open(src) as reader: + for obj in reader: + out.write(obj) + else: # Local directory case + uri_path = Path(uri) + for src in uri_path.glob("*.jsonl"): + if self._is_valid_jsonl_file(str(src)): + with jsonlines.open(src) as reader: + for obj in reader: + out.write(obj) + + return aa.shards_config["out"] + + def _process_directory_with_streaming( + self, + ctx: FlyteContext, + uri: str, + fd: FlyteDirectory = None, + base_type: typing.Type = None, + annotate_args: typing.List = None, + expected_python_type: typing.Type = None, + ) -> typing.Union[StreamingDataset, FlyteDirectory]: + for aa in annotate_args: + if isinstance(aa, StreamingKwargs): + # Process shards if configured + output_path = None + if aa.shards_config: + output_path = self._create_shards(ctx, uri, fd, aa) + + # Process streaming configuration if present + if aa.stream_config: + if output_path: # If shards were created + return StreamingDataset(local=output_path, **aa.stream_config) + else: # Use original path (assuming shards are present in the remote directory) + return StreamingDataset(remote=uri, **aa.stream_config) + + # Return appropriate object if we didn't create a StreamingDataset + if fd: # Remote case + return fd + else: # Local case + return base_type(uri, remote_directory=False) + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] - ) -> FlyteDirectory: + ) -> typing.Any: + base_type = None + has_streaming_kwargs = False + annotate_args = [] + + # Extract base type and annotations for StreamingKwargs + if get_origin(expected_python_type) is Annotated: + base_type, *annotate_args = get_args(expected_python_type) + has_streaming_kwargs = any(isinstance(arg, StreamingKwargs) for arg in annotate_args) + if has_streaming_kwargs and not _has_streaming: + raise TypeTransformerFailedError( + "In order to use StreamingKwargs, you need to install mosaicml-streaming first." + ) + # Handle dataclass attribute access if lv.scalar: if lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + return self.from_binary_idl(lv.scalar.binary, base_type or expected_python_type) if lv.scalar.generic: - return self.from_generic_idl(lv.scalar.generic, expected_python_type) + return self.from_generic_idl(lv.scalar.generic, base_type or expected_python_type) try: uri = lv.scalar.blob.uri except AttributeError: - raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {base_type or expected_python_type}") if lv.scalar.blob.metadata.type.dimensionality != BlobType.BlobDimensionality.MULTIPART: raise TypeTransformerFailedError(f"{lv.scalar.blob.uri} is not a directory.") @@ -653,22 +785,39 @@ async def async_to_python_value( if not ctx.file_access.is_remote(uri) and not os.path.isdir(uri): raise FlyteAssertion(f"Expected a directory, but the given uri '{uri}' is not a directory.") - # This is a local file path, like /usr/local/my_dir, don't mess with it. Certainly, downloading it doesn't - # make any sense. + # Local file path handling if not ctx.file_access.is_remote(uri): - return expected_python_type(uri, remote_directory=False) + if base_type and has_streaming_kwargs: + return self._process_directory_with_streaming( + ctx, + uri, + fd=None, + base_type=base_type, + annotate_args=annotate_args, + expected_python_type=expected_python_type, + ) + else: + return (base_type or expected_python_type)(uri, remote_directory=False) - # For the remote case, return a FlyteDirectory object that can download + # Remote file path handling local_folder = ctx.file_access.get_random_local_directory() - - batch_size = get_batch_size(expected_python_type) - + batch_size = get_batch_size(base_type or expected_python_type) _downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size) - - expected_format = self.get_format(expected_python_type) + expected_format = self.get_format(base_type or expected_python_type) fd = FlyteDirectory.__class_getitem__(expected_format)(local_folder, _downloader) fd._remote_source = uri + + if base_type and has_streaming_kwargs: + return self._process_directory_with_streaming( + ctx, + uri, + fd=fd, + base_type=base_type, + annotate_args=annotate_args, + expected_python_type=expected_python_type, + ) + return fd def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteDirectory[typing.Any]]: diff --git a/tests/flytekit/unit/types/directory/data_streamer.py b/tests/flytekit/unit/types/directory/data_streamer.py new file mode 100644 index 0000000000..a3640eead6 --- /dev/null +++ b/tests/flytekit/unit/types/directory/data_streamer.py @@ -0,0 +1,92 @@ +import json +import math +import os +from typing import Annotated + +from tqdm import tqdm + +import flytekit +from flytekit.types.directory import FlyteDirectory, StreamingKwargs + +image = flytekit.ImageSpec(name="mosaic-data-streaming", packages=["datasets", ""]) + + +@flytekit.task(cache=True) +def prepare_hf_dataset_to_jsonl( + dataset_name: str = "CohereForAI/aya_collection_language_split", + language: str = "algerian_arabic", + samples_per_file: int = 1000, + split: str = "validation", +) -> FlyteDirectory: + """Data preparation step.""" + from datasets import load_dataset + + output_dir = os.path.join( + flytekit.current_context().working_directory, "aya_algerian_arabic_jsonl" + ) + + os.makedirs(output_dir, exist_ok=True) + dataset = load_dataset(dataset_name, language, split=split) + total_samples = len(dataset) + num_files = math.ceil(total_samples / samples_per_file) + + for file_idx in range(num_files): + start_idx = file_idx * samples_per_file + end_idx = min((file_idx + 1) * samples_per_file, total_samples) + + output_file = os.path.join(output_dir, f"{split}_{file_idx:05d}.jsonl") + with open(output_file, "w", encoding="utf-8") as f: + subset = dataset.select(range(start_idx, end_idx)) + for sample in tqdm(subset, total=end_idx - start_idx): + f.write(json.dumps(dict(sample), ensure_ascii=False) + "\n") + + return output_dir + + +@flytekit.task +def stream_data( + dataset: Annotated[ + FlyteDirectory, + StreamingKwargs( + shards_config={ + "columns": { + "id": "int", + "inputs": str, + "targets": str, + "dataset_name": str, + "sub_dataset_name": str, + "task_type": str, + "template_id": int, + "language": str, + "split": str, + "script": str, + } + }, + stream_config={ + "batch_size": 5, + "sampling_method": "fixed", + "download_retry": 2, + "download_timeout": 120, + "sampling_granularity": 1, + }, + ), + ], +) -> int: + return dataset.num_samples + + +@flytekit.task +def read_dir_without_streaming(dataset: FlyteDirectory) -> int: + return 1 + + +@flytekit.workflow +def mosaic_data_streaming() -> int: + jsonl_dir = prepare_hf_dataset_to_jsonl() + return stream_data(dataset=jsonl_dir) + + +@flytekit.workflow +def dir_without_streaming() -> int: + jsonl_dir = prepare_hf_dataset_to_jsonl() + return read_dir_without_streaming(dataset=jsonl_dir) From 97d252bbee40f6b01398d11788e850dd2a8614de Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 11 Mar 2025 19:29:36 +0530 Subject: [PATCH 2/7] fix remote dir code Signed-off-by: Samhita Alla --- flytekit/types/directory/types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 0ea17aace8..e3a67ff813 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -708,8 +708,9 @@ def _create_shards( if fd: # Remote directory case for base, x in fd.crawl(): src = str(os.path.join(base, x)) - if self._is_valid_jsonl_file(src): - with jsonlines.open(src) as reader: + local_src_file = FlyteFile(src).download() + if self._is_valid_jsonl_file(local_src_file): + with jsonlines.open(local_src_file) as reader: for obj in reader: out.write(obj) else: # Local directory case From 1fabfeae1dc2f8f24cf0f7273615ee7573dc3dce Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 11 Mar 2025 19:34:34 +0530 Subject: [PATCH 3/7] fix remote dir code Signed-off-by: Samhita Alla --- flytekit/types/directory/types.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index e3a67ff813..44e09981e1 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -683,8 +683,6 @@ def wf(dc: DC): return self.dict_to_flyte_directory(python_val, expected_python_type) def _is_valid_jsonl_file(self, file_path: str) -> bool: - if not os.path.isfile(file_path) or not file_path.endswith(".jsonl"): - return False try: with jsonlines.open(file_path) as reader: for _ in reader: @@ -708,7 +706,7 @@ def _create_shards( if fd: # Remote directory case for base, x in fd.crawl(): src = str(os.path.join(base, x)) - local_src_file = FlyteFile(src).download() + local_src_file = FlyteFile.from_source(src) if self._is_valid_jsonl_file(local_src_file): with jsonlines.open(local_src_file) as reader: for obj in reader: From 7e2dc8a0d0da3c2bebbe5f1c0b2193898d122872 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 12 Mar 2025 16:14:40 +0530 Subject: [PATCH 4/7] add data format Signed-off-by: Samhita Alla --- flytekit/types/directory/types.py | 143 ++++++++++++++++++++++-------- 1 file changed, 106 insertions(+), 37 deletions(-) diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 44e09981e1..76c2c487fa 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -6,6 +6,7 @@ import random import typing from dataclasses import dataclass, field +from enum import Enum from functools import partial from pathlib import Path from typing import Annotated, Any, Dict, Generator, Tuple @@ -18,7 +19,7 @@ from fsspec.utils import get_protocol from google.protobuf import json_format as _json_format from google.protobuf.struct_pb2 import Struct -from marshmallow import fields +from marshmallow import fields, validate from mashumaro.types import SerializableType from typing_extensions import get_args, get_origin @@ -62,13 +63,23 @@ def noop(): ... +class DataFormat(Enum): + JSONL = "jsonl" + PARQUET = "parquet" + ARROW = "arrow" + + @dataclass class StreamingKwargs(DataClassJsonMixin): shards_config: typing.Dict[str, Any] = field(default=None, metadata=config(mm_field=fields.Dict())) stream_config: typing.Dict[str, Any] = field(default=None, metadata=config(mm_field=fields.Dict())) + data_format: str = field( + default=DataFormat.JSONL, + metadata=config(mm_field=fields.String(validate=validate.OneOf([format.value for format in DataFormat]))), + ) def __post_init__(self): - columns = self.shards_config.get("columns") + columns = self.shards_config.get("columns") if self.shards_config else None if columns: self.shards_config["columns"] = {k: v.__name__ if isinstance(v, type) else v for k, v in columns.items()} @@ -691,33 +702,100 @@ def _is_valid_jsonl_file(self, file_path: str) -> bool: except (jsonlines.InvalidLineError, UnicodeDecodeError): return False + def _is_valid_parquet_file(self, file_path: str) -> bool: + import pyarrow.parquet as pq + + try: + pq.ParquetFile(file_path) + return True + except Exception: + return False + + def _is_valid_arrow_file(self, file_path: str) -> bool: + import pyarrow as pa + + try: + with pa.memory_map(file_path, "r") as mmap: + pa.RecordBatchStreamReader(mmap) + return True + except Exception: + return False + + def _write_jsonl(self, out, src: str): + if self._is_valid_jsonl_file(src): + with jsonlines.open(src) as reader: + for obj in reader: + out.write(obj) + else: + raise ValueError(f"Invalid JSONL file: {src}") + + def _process_batches(self, out, reader): + import numpy as np + + for batch in reader: + records = ( + { + name: np.array(val.as_py()) + if hasattr(val, "as_py") and isinstance(val.as_py(), list) + else val.as_py() + if hasattr(val, "as_py") + else np.array(val) + if isinstance(val, list) + else val + for name, val in zip(batch.schema.names, row) + } + for row in zip(*batch.columns) + ) + for record in records: + out.write(record) + + def _write_parquet_or_arrow(self, out, src: str, is_parquet: bool = True): + import pyarrow as pa + import pyarrow.parquet as pq + + if is_parquet: + if not self._is_valid_parquet_file(src): + raise ValueError(f"Invalid Parquet file: {src}") + reader = pq.ParquetFile(src).iter_batches(batch_size=5) + else: + if not self._is_valid_arrow_file(src): + raise ValueError(f"Invalid Arrow file: {src}") + with pa.memory_map(src, "r") as mmap, pa.RecordBatchStreamReader(mmap) as reader: + self._process_batches(out, reader) + return + + self._process_batches(out, reader) + def _create_shards( self, ctx: FlyteContext, uri: str, fd: FlyteDirectory = None, aa: StreamingKwargs = None ) -> typing.Union[bool, str]: if not aa.shards_config: return None - # Set output path if not provided - if "out" not in aa.shards_config: - aa.shards_config["out"] = ctx.file_access.get_random_local_directory() + aa.shards_config.setdefault("out", ctx.file_access.get_random_local_directory()) + + if aa.data_format == DataFormat.JSONL: + writer_func = self._write_jsonl + elif aa.data_format == DataFormat.PARQUET: + writer_func = self._write_parquet_or_arrow + elif aa.data_format == DataFormat.ARROW: + writer_func = partial(self._write_parquet_or_arrow, is_parquet=False) + else: + raise ValueError(f"Unsupported data format: {aa.data_format}") - # Create shards with MDSWriter(**aa.shards_config) as out: - if fd: # Remote directory case - for base, x in fd.crawl(): - src = str(os.path.join(base, x)) - local_src_file = FlyteFile.from_source(src) - if self._is_valid_jsonl_file(local_src_file): - with jsonlines.open(local_src_file) as reader: - for obj in reader: - out.write(obj) - else: # Local directory case - uri_path = Path(uri) - for src in uri_path.glob("*.jsonl"): - if self._is_valid_jsonl_file(str(src)): - with jsonlines.open(src) as reader: - for obj in reader: - out.write(obj) + if fd: + sources = (FlyteFile.from_source(str(Path(base) / x)) for base, x in fd.crawl()) + else: + sources = Path(uri).rglob(f"*.{aa.data_format.name.lower()}") + + try: + first_source = next(sources) # Check if at least one file exists + writer_func(out, str(first_source)) + for src in sources: # Continue processing the rest + writer_func(out, str(src)) + except StopIteration: + raise ValueError(f"No {aa.data_format.name.lower()} files found in {uri}") return aa.shards_config["out"] @@ -728,27 +806,20 @@ def _process_directory_with_streaming( fd: FlyteDirectory = None, base_type: typing.Type = None, annotate_args: typing.List = None, - expected_python_type: typing.Type = None, ) -> typing.Union[StreamingDataset, FlyteDirectory]: for aa in annotate_args: if isinstance(aa, StreamingKwargs): # Process shards if configured - output_path = None - if aa.shards_config: - output_path = self._create_shards(ctx, uri, fd, aa) + output_path = self._create_shards(ctx, uri, fd, aa) if aa.shards_config else None - # Process streaming configuration if present - if aa.stream_config: - if output_path: # If shards were created - return StreamingDataset(local=output_path, **aa.stream_config) - else: # Use original path (assuming shards are present in the remote directory) - return StreamingDataset(remote=uri, **aa.stream_config) + return StreamingDataset( + local=output_path or uri if output_path or not fd else None, + remote=uri if not output_path and fd else None, + **(aa.stream_config or {}), # Make config optional + ) # Return appropriate object if we didn't create a StreamingDataset - if fd: # Remote case - return fd - else: # Local case - return base_type(uri, remote_directory=False) + return fd or base_type(uri, remote_directory=False) async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Type[FlyteDirectory] @@ -793,7 +864,6 @@ async def async_to_python_value( fd=None, base_type=base_type, annotate_args=annotate_args, - expected_python_type=expected_python_type, ) else: return (base_type or expected_python_type)(uri, remote_directory=False) @@ -814,7 +884,6 @@ async def async_to_python_value( fd=fd, base_type=base_type, annotate_args=annotate_args, - expected_python_type=expected_python_type, ) return fd From a0434c4753b5b7a9a818ee7fa03892d7dd8227da Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 12 Mar 2025 16:26:27 +0530 Subject: [PATCH 5/7] add init file Signed-off-by: Samhita Alla --- flytekit/types/directory/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index 419ee0a51b..e172f71557 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -16,7 +16,7 @@ import typing -from .types import FlyteDirectory, FlyteDirToMultipartBlobTransformer, StreamingKwargs +from .types import DataFormat, FlyteDirectory, FlyteDirToMultipartBlobTransformer, StreamingKwargs # The following section provides some predefined aliases for commonly used FlyteDirectory formats. From 1781d4f9af97ddb5697246ddcee1fd6c873a6221 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 13 Mar 2025 19:41:34 +0530 Subject: [PATCH 6/7] don't raise exceptions Signed-off-by: Samhita Alla --- flytekit/types/directory/types.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 76c2c487fa..e55c3cdf80 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -726,8 +726,6 @@ def _write_jsonl(self, out, src: str): with jsonlines.open(src) as reader: for obj in reader: out.write(obj) - else: - raise ValueError(f"Invalid JSONL file: {src}") def _process_batches(self, out, reader): import numpy as np @@ -754,17 +752,15 @@ def _write_parquet_or_arrow(self, out, src: str, is_parquet: bool = True): import pyarrow.parquet as pq if is_parquet: - if not self._is_valid_parquet_file(src): - raise ValueError(f"Invalid Parquet file: {src}") - reader = pq.ParquetFile(src).iter_batches(batch_size=5) - else: - if not self._is_valid_arrow_file(src): - raise ValueError(f"Invalid Arrow file: {src}") - with pa.memory_map(src, "r") as mmap, pa.RecordBatchStreamReader(mmap) as reader: + if self._is_valid_parquet_file(src): + reader = pq.ParquetFile(src).iter_batches(batch_size=5) self._process_batches(out, reader) - return - - self._process_batches(out, reader) + else: + if self._is_valid_arrow_file(src): + with pa.memory_map(src, "r") as mmap, pa.RecordBatchStreamReader(mmap) as reader: + self._process_batches(out, reader) + return + return def _create_shards( self, ctx: FlyteContext, uri: str, fd: FlyteDirectory = None, aa: StreamingKwargs = None @@ -785,7 +781,7 @@ def _create_shards( with MDSWriter(**aa.shards_config) as out: if fd: - sources = (FlyteFile.from_source(str(Path(base) / x)) for base, x in fd.crawl()) + sources = (FlyteFile.from_source(os.path.join(base, x)).download() for base, x in fd.crawl()) else: sources = Path(uri).rglob(f"*.{aa.data_format.name.lower()}") @@ -793,7 +789,7 @@ def _create_shards( first_source = next(sources) # Check if at least one file exists writer_func(out, str(first_source)) for src in sources: # Continue processing the rest - writer_func(out, str(src)) + writer_func(out, src) except StopIteration: raise ValueError(f"No {aa.data_format.name.lower()} files found in {uri}") From 9438bbc3d66a00709b3e84e53022b11d3524de52 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 17 Apr 2025 22:05:53 +0530 Subject: [PATCH 7/7] remove examples Signed-off-by: Samhita Alla --- .../unit/types/directory/data_streamer.py | 92 ------------------- 1 file changed, 92 deletions(-) delete mode 100644 tests/flytekit/unit/types/directory/data_streamer.py diff --git a/tests/flytekit/unit/types/directory/data_streamer.py b/tests/flytekit/unit/types/directory/data_streamer.py deleted file mode 100644 index a3640eead6..0000000000 --- a/tests/flytekit/unit/types/directory/data_streamer.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -import math -import os -from typing import Annotated - -from tqdm import tqdm - -import flytekit -from flytekit.types.directory import FlyteDirectory, StreamingKwargs - -image = flytekit.ImageSpec(name="mosaic-data-streaming", packages=["datasets", ""]) - - -@flytekit.task(cache=True) -def prepare_hf_dataset_to_jsonl( - dataset_name: str = "CohereForAI/aya_collection_language_split", - language: str = "algerian_arabic", - samples_per_file: int = 1000, - split: str = "validation", -) -> FlyteDirectory: - """Data preparation step.""" - from datasets import load_dataset - - output_dir = os.path.join( - flytekit.current_context().working_directory, "aya_algerian_arabic_jsonl" - ) - - os.makedirs(output_dir, exist_ok=True) - dataset = load_dataset(dataset_name, language, split=split) - total_samples = len(dataset) - num_files = math.ceil(total_samples / samples_per_file) - - for file_idx in range(num_files): - start_idx = file_idx * samples_per_file - end_idx = min((file_idx + 1) * samples_per_file, total_samples) - - output_file = os.path.join(output_dir, f"{split}_{file_idx:05d}.jsonl") - with open(output_file, "w", encoding="utf-8") as f: - subset = dataset.select(range(start_idx, end_idx)) - for sample in tqdm(subset, total=end_idx - start_idx): - f.write(json.dumps(dict(sample), ensure_ascii=False) + "\n") - - return output_dir - - -@flytekit.task -def stream_data( - dataset: Annotated[ - FlyteDirectory, - StreamingKwargs( - shards_config={ - "columns": { - "id": "int", - "inputs": str, - "targets": str, - "dataset_name": str, - "sub_dataset_name": str, - "task_type": str, - "template_id": int, - "language": str, - "split": str, - "script": str, - } - }, - stream_config={ - "batch_size": 5, - "sampling_method": "fixed", - "download_retry": 2, - "download_timeout": 120, - "sampling_granularity": 1, - }, - ), - ], -) -> int: - return dataset.num_samples - - -@flytekit.task -def read_dir_without_streaming(dataset: FlyteDirectory) -> int: - return 1 - - -@flytekit.workflow -def mosaic_data_streaming() -> int: - jsonl_dir = prepare_hf_dataset_to_jsonl() - return stream_data(dataset=jsonl_dir) - - -@flytekit.workflow -def dir_without_streaming() -> int: - jsonl_dir = prepare_hf_dataset_to_jsonl() - return read_dir_without_streaming(dataset=jsonl_dir)