diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9b1b18e4..7dc1c8a95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: rev: 23.9.1 hooks: - id: black - language_version: python3.10 + language_version: python3.11 stages: [pre-commit] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.0.278 diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml index be9060ee6..1f3aadb00 100644 --- a/config_files/training/config_example_coca.yaml +++ b/config_files/training/config_example_coca.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_num_steps config: num_steps: ${settings.training_target.num_target_steps} @@ -36,7 +36,7 @@ settings: sequence_length: ${settings.step_profile.sequence_length} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_samples config: num_ranks: ${settings.cuda_env.world_size} diff --git a/config_files/training/config_lorem_ipsum.yaml b/config_files/training/config_lorem_ipsum.yaml index 670610da6..231f59dc6 100644 --- a/config_files/training/config_lorem_ipsum.yaml +++ b/config_files/training/config_lorem_ipsum.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 34acf6263..0a88b1387 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -13,7 +13,9 @@ from modalities.api import ( convert_pytorch_to_hf_checkpoint, - create_raw_data_index, + create_global_index, + create_local_index, + create_shuffled_global_index, generate_text, merge_packed_data_files, pack_encoded_data, @@ -34,6 +36,7 @@ from modalities.running_env.cuda_env import CudaEnv from modalities.trainer import Trainer from modalities.util import get_total_number_of_trainable_parameters, print_rank_0 +from modalities.utils.logging import get_logger @click.group() @@ -123,7 +126,7 @@ def data(): pass -@data.command(name="create_raw_index") +@data.command(name="create_local_index") @click.argument("src_path", type=Path) @click.option( "--index_path", @@ -131,7 +134,7 @@ def data(): default=None, help="output path for index. will use parent directory of src_path if none.", ) -def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path): +def CMD_entry_point_data_create_local_index(src_path: Path, index_path: Path): """Utility CMD for indexing the content of a large jsonl-file. Background is the ability to further process the respective file without loading it, while splitting its content line-based. 
This step is necessary in advance of further processing like tokenization. @@ -144,11 +147,27 @@ def CMD_entry_point_data_create_raw_index(src_path: Path, index_path: Path): Raises: ValueError: If the index file already exists. """ - create_raw_data_index(src_path=src_path, index_path=index_path) + create_local_index(src_path=src_path, index_path=index_path) + + +@data.command(name="create_global_index") +@click.option("--file_list_path", type=Path, required=True) +@click.option("--root_index_path", type=Path, required=True) +@click.option("--global_index_root_path", type=Path, required=True) +def CMD_entry_point_create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path): + create_global_index( + file_list_path=file_list_path, root_index_path=root_index_path, global_index_root_path=global_index_root_path + ) + + +@data.command(name="create_shuffled_global_index") +@click.option("--global_index_file_path", type=Path, required=True) +def CMD_entry_point_create_shuffled_global_index(global_index_file_path: Path): + create_shuffled_global_index(global_index_file_path=global_index_file_path) @data.command(name="pack_encoded_data") -@click.argument("config_path", type=FilePath) +@click.option("--config_path", type=FilePath, required=True) def CMD_entry_point_pack_encoded_data(config_path: FilePath): """Utility to encode an indexed, large jsonl-file. (see also `create_index` for more information) @@ -158,6 +177,7 @@ def CMD_entry_point_pack_encoded_data(config_path: FilePath): Args: config_path (FilePath): Path to the config file describing the tokenization setup. """ + get_logger().info(f"Loading config from {config_path}.") config_dict = load_app_config_dict(config_path) pack_encoded_data(config_dict=config_dict) diff --git a/src/modalities/api.py b/src/modalities/api.py index 05f8ef2c2..ab75a9f48 100644 --- a/src/modalities/api.py +++ b/src/modalities/api.py @@ -1,23 +1,42 @@ #!/usr/bin/env python +import multiprocessing as mp import os +from enum import Enum from pathlib import Path from pydantic import FilePath +import modalities.dataloader.preprocessing.indexation.global_indexation as global_indexation import modalities.inference.inference as inference from modalities.checkpointing.checkpoint_conversion import CheckpointConversion from modalities.config.component_factory import ComponentFactory -from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel -from modalities.dataloader.create_index import IndexGenerator -from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +from modalities.config.instantiation_models import TokenizationInstantiationModel +from modalities.dataloader.preprocessing.indexation.local_indexation import IndexGenerator +from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController +from modalities.dataloader.preprocessing.queued_processing.processors import Processor +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import ( + EmbeddedStreamData, + join_embedded_stream_data, +) +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader +from modalities.dataloader.preprocessing.tokenization.strategies import ProcessingStrategyFactory, WorkerTypes from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter from 
modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry +from modalities.utils.env_variables import temporary_env_vars_decorator +from modalities.utils.logging import get_logger -def create_raw_data_index(src_path: Path, index_path: Path): +class FileExistencePolicy(Enum): + SKIP = "skip" + ERROR = "error" + OVERRIDE = "override" + + +def create_local_index( + src_path: Path, index_path: Path, file_existence_policy: FileExistencePolicy = FileExistencePolicy.ERROR +): """Creates the index file for the content of a large jsonl-file. The index file contains the byte-offsets and lengths of each line in the jsonl-file. Background is the ability to further process the respective file without loading it, @@ -31,17 +50,40 @@ def create_raw_data_index(src_path: Path, index_path: Path): Raises: ValueError: If the index file already exists. """ - index_path = LargeFileLinesReader.default_index_path(src_path, index_path) - os.makedirs(index_path.parent, exist_ok=True) + index_path = LocalLargeFileLinesReader.default_index_path(src_path, index_path) if index_path.exists(): - raise ValueError("index already exists. delete it or specify different output folder.") + if file_existence_policy == FileExistencePolicy.SKIP: + get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Skipping index creation.") + return + elif file_existence_policy == FileExistencePolicy.OVERRIDE: + get_logger(name="main").warning(f"Index already exists at {str(index_path)}. Overriding it.") + os.remove(index_path) + elif file_existence_policy == FileExistencePolicy.ERROR: + raise ValueError("index already exists. delete it or specify different output folder.") + else: + raise ValueError(f"Unknown file existence policy: {file_existence_policy}") + + get_logger(name="main").info( + f"Reading raw data from {str(src_path)} and" f" writing index to {str(index_path)} ..." + ) + os.makedirs(index_path.parent, exist_ok=True) - print(f"reading raw data from {src_path}") - print(f"writing index to {index_path}") generator = IndexGenerator(src_path) generator.create_index(index_path) +def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path: + global_index_file_path = global_indexation.create_global_index( + file_list_path, root_index_path, global_index_root_path + ) + return global_index_file_path + + +def create_shuffled_global_index(global_index_file_path: Path) -> Path: + global_shuffled_index_file_path = global_indexation.create_shuffled_global_index(global_index_file_path) + return global_shuffled_index_file_path + + def generate_text(config_file_path: FilePath): """Inference function to generate text with a given model. @@ -70,6 +112,9 @@ def convert_pytorch_to_hf_checkpoint( return hf_model +# not setting this can cause deadlocks when using hf's "FastTokenizers". See also: +# https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 +@temporary_env_vars_decorator({"TOKENIZERS_PARALLELISM": "false"}) def pack_encoded_data(config_dict: dict): """Packs and encodes an indexed, large jsonl-file. (see also `create_index` for more information) @@ -79,31 +124,129 @@ def pack_encoded_data(config_dict: dict): Args: config_dict (dict): Dictionary containing the configuration for the packed data generation. """ - - # TODO: if we want to use alternative entrypoints together with the ResolverRegistry, - # we can currently not rely on the existing class resolver. 
- # This is based on its connection to the overall `AppConfig`. - # One would requires an object of it to instantiate the ResolverRegistry. - # This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing - # ResolverRegistry to work dynamically with any type-hinted config object from config.py. registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) - components: PackedDatasetComponentsInstantiationModel = component_factory.build_components( - config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel + instantion_model: TokenizationInstantiationModel = component_factory.build_components( + config_dict=config_dict, components_model_type=TokenizationInstantiationModel + ) + + # build the queues + reader_q, tokenizer_q, writer_q, logging_message_q = ProcessingStrategyFactory.get_process_queues( + reader_q_maxsize=instantion_model.reader_q_maxsize, + writer_q_maxsize=instantion_model.writer_q_maxsize, + tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize, + ) + + # build the workers + stop_event = mp.Event() + reader_q_key = "reader_q" + tokenizer_q_key = "tokenizer_q" + writer_q_key = "writer_q" + logging_message_q_key = "logging_message_q" + + populating_worker = Processor( + out_qs={reader_q_key: reader_q, logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_populating_strategy( + reader_q_key=reader_q_key, + logging_message_q_key=logging_message_q_key, + index_start=instantion_model.populate_worker_settings.index_start, + num_samples=instantion_model.populate_worker_settings.num_samples, + batch_size=instantion_model.populate_worker_settings.batch_size, + ), + process_type=WorkerTypes.POPULATOR, + process_id=0, + logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=True, + stop_event=stop_event, ) - generator = PackedDataGenerator( - components.settings.src_path, - index_path=components.settings.index_path, - tokenizer=components.tokenizer, - eod_token=components.settings.eod_token, - jq_pattern=components.settings.jq_pattern, - number_of_processes=components.settings.num_cpus, - processing_batch_size=components.settings.processing_batch_size, - raw_samples_queue_size=components.settings.raw_samples_queue_size, - processed_samples_queue_size=components.settings.processed_samples_queue_size, + reader_settings = instantion_model.reader_worker_settings.reader_settings + reader_workers = [ + Processor( + in_q=reader_q, + out_qs={tokenizer_q_key: tokenizer_q, logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_reader_strategy( + reader_settings, tokenizer_q_key=tokenizer_q_key, logging_message_q_key=logging_message_q_key + ), + process_type=WorkerTypes.READER, + process_id=i, + logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=False, + stop_event=stop_event, + ) + for i in range(instantion_model.reader_worker_settings.num_workers) + ] + + tokenizer_workers = [ + Processor( + in_q=tokenizer_q, + out_qs={writer_q_key: writer_q, logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_tokenizer_strategy( + 
tokenizer_settings=instantion_model.tokenizer_worker_settings.tokenizer_settings, + writer_q_key=writer_q_key, + logging_message_q_key=logging_message_q_key, + ), + process_type=WorkerTypes.TOKENIZER, + process_id=i, + logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=False, + stop_event=stop_event, + ) + for i in range(instantion_model.tokenizer_worker_settings.num_workers) + ] + + writer_worker = Processor( + in_q=writer_q, + out_qs={logging_message_q_key: logging_message_q}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_writing_strategy( + ww_settings=instantion_model.writer_worker_settings, logging_message_q_key=logging_message_q_key + ), + process_type=WorkerTypes.WRITER, + process_id=0, + logging_message_q_key=logging_message_q_key, + set_stop_event_on_processing_error=True, + stop_event=stop_event, + ) + + logging_worker = Processor( + in_q=logging_message_q, + out_qs={}, + in_q_timeout=instantion_model.in_q_timeout, + out_q_timeout=instantion_model.out_q_timeout, + strategy=ProcessingStrategyFactory.get_progress_logging_strategy( + logging_interval=instantion_model.logging_worker_settings.logging_interval, + total_num_samples=instantion_model.logging_worker_settings.num_samples, + q_dict={ + reader_q_key: reader_q, + tokenizer_q_key: tokenizer_q, + writer_q_key: writer_q, + logging_message_q_key: logging_message_q, + }, + ), + process_type=WorkerTypes.LOGGING, + process_id=0, + set_stop_event_on_processing_error=False, + stop_event=stop_event, ) - generator.run(components.settings.dst_path) + + pipeline_steps = [ + PipelineStep(name="populating", input_queue=None, processors=[populating_worker], poisonable=False), + PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers, poisonable=True), + PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers, poisonable=True), + PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker], poisonable=True), + PipelineStep(name="logging", input_queue=logging_message_q, processors=[logging_worker], poisonable=True), + ] + + process_controller = ProcessController(pipeline_steps=pipeline_steps, stop_event=stop_event) + process_controller.run() def merge_packed_data_files(src_paths: list[Path], target_path: Path): diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index edafccae5..16232d17e 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Annotated, Any, Optional @@ -16,10 +15,10 @@ PydanticPytorchDeviceType, PydanticPytorchModuleType, PydanticTextInferenceComponentType, - PydanticTokenizerIFType, ) from modalities.config.utils import parse_torch_device from modalities.dataloader.dataset import Dataset +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LargeFileLinesReaderTypes from modalities.util import warn_rank_0 @@ -191,20 +190,72 @@ def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationMode return self -class PackedDatasetComponentsInstantiationModel(BaseModel): - class PackedDatasetSettings(BaseModel): - src_path: FilePath - dst_path: Optional[Path] = None - index_path: Optional[FilePath] = None - jq_pattern: str - num_cpus: Annotated[int, Field(strict=True, ge=1)] = os.cpu_count() - eod_token: str - 
processing_batch_size: Annotated[int, Field(strict=True, ge=1)] - raw_samples_queue_size: Annotated[int, Field(strict=True, ge=1)] - processed_samples_queue_size: Annotated[int, Field(strict=True, ge=1)] - - tokenizer: PydanticTokenizerIFType - settings: PackedDatasetSettings +class TokenizationInstantiationModel(BaseModel): + class PopulateWorkerSettings(BaseModel): + num_samples: Annotated[int, Field(strict=True, ge=1)] + batch_size: Annotated[int, Field(strict=True, ge=1)] + index_start: Optional[Annotated[int, Field(strict=True, ge=0)]] = 0 + + class ReaderWorkerSettings(BaseModel): + class ReaderSettings(BaseModel): + class LocalReaderArgs(BaseModel): + raw_data_path: Path + index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + class GlobalReaderArgs(BaseModel): + global_inorder_index_path: Path + raw_data_file_list_path: Path + raw_data_root_path: Path + global_shuffle_index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + reader_type: LargeFileLinesReaderTypes + reader_args: LocalReaderArgs | GlobalReaderArgs + + num_workers: Annotated[int, Field(strict=True, ge=1)] + reader_settings: ReaderSettings + + class TokenizerWorkerSettings(BaseModel): + class TokenizerSettings(BaseModel): + class TokenizerInstantitionSettings(BaseModel): + tokenizer_component_key: str + tokenizer_variant_key: str + config: dict[str, Any] + + tokenizer_instantiation_settings: TokenizerInstantitionSettings + eod_token: str + jq_pattern: str + + num_workers: Annotated[int, Field(strict=True, ge=1)] + tokenizer_settings: TokenizerSettings + + class WriterWorkerSettings(BaseModel): + dst_path: Path + index_start: Annotated[int, Field(strict=True, ge=0)] + + @field_validator("dst_path") + def ensure_path_does_not_exist(cls, value): + path = Path(value) # Convert to Path object if it's a string + if path.exists(): + raise ValueError(f"The filepath '{path}' already exists.") + return path + + class LoggingWorkerSettings(BaseModel): + logging_interval: Annotated[int, Field(strict=True, ge=1)] + num_samples: Optional[Annotated[int, Field(strict=True, ge=1)]] = None + + paths: dict[str, Path] + populate_worker_settings: PopulateWorkerSettings + reader_worker_settings: ReaderWorkerSettings + tokenizer_worker_settings: TokenizerWorkerSettings + writer_worker_settings: WriterWorkerSettings + logging_worker_settings: LoggingWorkerSettings + reader_q_maxsize: Annotated[int, Field(strict=True, ge=1)] + tokenizer_q_maxsize: Annotated[int, Field(strict=True, ge=1)] + writer_q_maxsize: Annotated[int, Field(strict=True, ge=1)] + in_q_timeout: Annotated[int, Field(strict=True, ge=0)] + out_q_timeout: Annotated[int, Field(strict=True, ge=0)] class TextGenerationInstantiationModel(BaseModel): diff --git a/src/modalities/config/pydanctic_if_types.py b/src/modalities/config/pydanctic_if_types.py index 3761eb8df..4256113ed 100644 --- a/src/modalities/config/pydanctic_if_types.py +++ b/src/modalities/config/pydanctic_if_types.py @@ -14,6 +14,7 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving, CheckpointSavingExecutionABC from modalities.checkpointing.checkpoint_saving_strategies import CheckpointSavingStrategyIF from modalities.dataloader.dataloader import LLMDataLoader +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import BaseReaderIF from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss @@ -65,3 
+66,4 @@ def __get_pydantic_core_schema__( PydanticTextInferenceComponentType = Annotated[TextInferenceComponent, PydanticThirdPartyTypeIF(TextInferenceComponent)] PydanticGradientClipperIFType = Annotated[GradientClipperIF, PydanticThirdPartyTypeIF(GradientClipperIF)] PydanticModelInitializationIFType = Annotated[ModelInitializationIF, PydanticThirdPartyTypeIF(ModelInitializationIF)] +PydanticBaseReaderIFType = Annotated[BaseReaderIF, PydanticThirdPartyTypeIF(BaseReaderIF)] diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py deleted file mode 100644 index 5984de182..000000000 --- a/src/modalities/dataloader/create_packed_data.py +++ /dev/null @@ -1,444 +0,0 @@ -import logging -import math -import multiprocessing -import os -import pickle -import traceback -import warnings -from io import BufferedWriter -from pathlib import Path -from typing import Callable, Iterator, Optional - -import jq -import numpy as np -from pydantic import FilePath -from tqdm import tqdm - -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader -from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper - -logger = logging.getLogger(__name__) - - -class EmptySampleError(RuntimeError): - pass - - -class PackedDataGenerator: - """Reads in a JSONL file and the corresponding index file and packs the dataset for LLM training.""" - - def __init__( - self, - src_path: FilePath, - tokenizer: TokenizerWrapper, - eod_token: str, - number_of_processes: int, - jq_pattern: str, - processing_batch_size: int, - raw_samples_queue_size: int, - processed_samples_queue_size: int, - index_path: Optional[FilePath] = None, - ): - """ - Initializes a PackedDataGenerator object. - - Args: - src_path (FilePath): Path to a JSONL file, which holds text data. - tokenizer (TokenizerWrapper): PretrainedTokenizer object used to tokenize the provided data in `src_path`. - eod_token (str): End-of-document token. - number_of_processes (int): Number of processes used for parallel processing. - jq_pattern (str): jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed. - processing_batch_size (int): Size of the batches that the workers process. - raw_samples_queue_size (int): Maximum size of the raw samples queue. - processed_samples_queue_size (int): Maximum size of the processed samples queue. - index_path (Optional[FilePath], optional): Path to an index file, - which indicates the start character position - and length of samples given in `src_path`. If not defined, an index file next to `src_path` is picked, - by replacing its suffix with ".idx". Defaults to None. 
- - Returns: - None - """ - self.src_path = src_path - self.tokenizer = tokenizer - self.eod_token = eod_token - self._token_size_in_bytes = self._get_required_num_of_bytes_to_repr(self.tokenizer.vocab_size) - eod_token_id = self.tokenizer.get_token_id(self.eod_token) - self._encoded_eod_token_as_bytes = self._encoded_token_to_bytes(eod_token_id) - self.jq_filter = jq.compile(jq_pattern) - self._number_of_processes = number_of_processes - self._reader = LargeFileLinesReader(src_path, index_path=index_path) # reads string with utf-8 encoding - self._total_num_of_tokens = 0 - self._raw_samples_queue = multiprocessing.Queue(maxsize=raw_samples_queue_size) - self.processed_samples_queue = multiprocessing.Queue(maxsize=processed_samples_queue_size) - self._exception_buffer = [] - self.processing_batch_size = processing_batch_size - - @staticmethod - def _get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: - """ - Calculates the required number of bytes to represent an integer. - - Args: - int_to_get_repr (int): The integer to get the representation for. - - Returns: - int: The number of bytes required to represent the integer. - """ - # we currently only support token sizes of 1, 2 and 4 bytes, as implemented here: - # https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202 - num_bytes = math.ceil(math.log2(int_to_get_repr) / 8) - if num_bytes == 1: - return 1 - elif num_bytes == 2: - return 2 - elif num_bytes <= 4: - return 4 - else: - raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") - - def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: - """ - Converts an encoded token to its byte representaion. - - Args: - encoded_token (int): The encoded token to be converted. - - Returns: - bytes: The byte representation of the token. - - """ - return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="little", signed=False) - - def _default_destination_path(self, destination_path: Optional[Path] = None) -> Path: - """ - Returns the default destination path for the packed data. - - Args: - destination_path (Path, optional): The specific destination path. Defaults to None. - - Returns: - Path: The default destination path for the packed data. - """ - if destination_path is None: - default_destination_path = Path(self.src_path.parent, f"{self.src_path.stem}.pbin") - print( - f"No specific Destination Path provided. " - f"Pointing to destination next to input data at: {default_destination_path}" - ) - return default_destination_path - return Path(destination_path) - - def run(self, dst_path: Optional[Path] = None): - """ - Packs data and saves it to (default) dst_path. - - Args: - dst_path (Optional[Path]): The destination path to save the packed data. - If not provided, a default destination path will be used. - - Raises: - ValueError: If the file already exists at the destination path. - Exception: If an exception occurs during the data packing process. - - Returns: - None - """ - assert self._total_num_of_tokens == 0, f"This {self.__name__} was already used and is exhausted. Use another!" - dst_path = self._default_destination_path(destination_path=dst_path) - - dst_path.parent.mkdir(parents=True, exist_ok=True) - if dst_path.exists(): - raise ValueError(f"file already exists at destination path '{dst_path}'.") - - self._exception_buffer = [] - try: - # not setting this can cause deadlocks when using hf's "FastTokenizers". 
See also: - # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 - os.environ["TOKENIZERS_PARALLELISM"] = "false" - self._launch_parallelized_workers(dst_path) - finally: - os.unsetenv("TOKENIZERS_PARALLELISM") - - if self._exception_buffer: - raise self._exception_buffer[0] - - def _launch_parallelized_workers(self, dst_path: Path): - # Launches workers in parallel for reading, writing, and processing data. - # The data is stored in the provided destination path. - - reader = multiprocessing.Process(target=self._reader_thread()) - reader.start() - - writer = multiprocessing.Process(target=self._writer_thread(dst_path)) - writer.start() - processor_threads = [ - multiprocessing.Process(target=self._process_thread, args=(i,)) for i in range(self._number_of_processes) - ] - for p in processor_threads: - p.start() - for p in processor_threads: - p.join() - self._stop_processing() - writer.join() - - def _stop_processing(self): - # Stops the processing of samples by putting None in the processed_samples_queue. - self.processed_samples_queue.put(None) - - def _generator_for_tokens_to_get_written(self): - # Generator function that yields batches of processed samples. - - while True: - if self._check_for_parallel_errors(): - return - batch = self.processed_samples_queue.get() - if batch is None: - break - yield batch - - def _check_for_parallel_errors(self) -> bool: - # Checks if there are any errors in the exception buffer. - return bool(self._exception_buffer) - - def _writer_thread(self, dst_path: Path) -> Callable: - # Returns a callable writer function that writes a batch - # received from the processed_samples_queue to the destination file. - - def writer(): - # writes a batch received from the processed_samples_queue to the destination file - def _write_batch( - batch: list[tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter - ) -> tuple[int, int]: - # write the tokens for each document - for line_id, tokens_as_bytes in batch: - if prev_line_id + 1 != line_id: - raise ValueError( - f"Line IDs are not consecutive. 
Expected {prev_line_id + 1}, but got {line_id}" - ) - f.write(tokens_as_bytes) - segment_length = len(tokens_as_bytes) - index_list.append((curr_offset, segment_length)) - curr_offset += segment_length - prev_line_id = line_id - return prev_line_id, curr_offset - - index_list = [] - with dst_path.open("wb") as f: - # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) - # not possible to prepend header after determining size of data section - f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) - f.write( - self._token_size_in_bytes.to_bytes( - EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little" - ) - ) - # The offset only applies to the data section, not the header - # When we load the file, we add the header size to the offset - curr_offset = 0 - - # write data section (tokens) - pbar = tqdm(total=len(self._reader), desc="Processed batches") - prev_line_id = -1 - batch_dict = {} - for batch in self._generator_for_tokens_to_get_written(): - line_id = batch[0][0] - batch_dict[line_id] = batch - - while prev_line_id + 1 in batch_dict: - batch = batch_dict.pop(prev_line_id + 1) - prev_line_id, curr_offset = _write_batch(batch, prev_line_id, curr_offset, index_list, f) - pbar.update(len(batch)) - # write index - f.write(pickle.dumps(index_list)) - - self._update_data_length_in_pre_allocated_header(dst_path, index_list) - - return writer - - def _reader_thread(self) -> Callable: - # returns a reader function that reads lines from the reader and puts them into a queue. - def reader(): - batch = [] - for line_id, line in tqdm(enumerate(self._reader), desc="Reading jsonl", disable=True): - batch.append((line_id, line)) - if len(batch) % self.processing_batch_size == 0: - self._raw_samples_queue.put(batch) - batch = [] - - # add the remaining samples - if len(batch) > 0: - self._raw_samples_queue.put(batch) - - for _ in range(self._number_of_processes): - self._raw_samples_queue.put(None) - - return reader - - def _process_thread(self, process_id: int): - # Process the lines in a batch and put the processed samples into the processed_samples_queue. - if self._check_for_parallel_errors(): - return - - while True: - if self._check_for_parallel_errors(): - return - batch = self._raw_samples_queue.get() - if batch is None: - break - - try: - batch_processed = [] - for line_id, line in batch: - processed_line = self._process_line(line, process_id) - batch_processed.append((line_id, processed_line)) - self.processed_samples_queue.put(batch_processed) - except EmptySampleError: - warnings.warn( - f"Encountered empty sample in line {line_id} of file {self.src_path} within process {process_id}" - ) - except Exception as exception: - warnings.warn( - f"Could not process line {line_id} in {self.src_path} within process {process_id}. " - f"Raised the following error: {exception=}" - ) - traceback.print_exc() - - def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: list[tuple[int, int]]): - # Update the length of the data section in the pre-allocated header of the destination file. - # The data segment length is sum of the starting position and the length of the last document. 
- length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] - data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( - EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little" - ) - with dst_path.open("rb+") as fout: - fout.seek(0) - fout.write(data_section_length_in_bytes) - - def _process_line(self, line: str, process_id: int) -> bytes: - # extracts the text via the jq_filter and applies tokenization to the extract text - jq_retrieved_text = self.jq_filter.input_text(line).first() - if jq_retrieved_text is None: - raise ValueError(f"jq was not able to find anything using the expression: {self.jq_filter}") - tokens = self.tokenizer.tokenize(jq_retrieved_text) - if len(tokens) == 0: - raise EmptySampleError("Received empty sample...") - token_byte_string = b"".join(map(self._encoded_token_to_bytes, tokens)) - if not token_byte_string.endswith(self._encoded_eod_token_as_bytes): - token_byte_string = token_byte_string + self._encoded_eod_token_as_bytes - return token_byte_string - - -class EmbeddedStreamData: - # amount of bytes to represent number of all tokens in dataset. - # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. - # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides - DATA_SECTION_LENGTH_IN_BYTES = 8 - TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4 - HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES - - def __init__(self, data_path: Path, load_index: Optional[bool] = True): - """ - Initializes an EmbeddedStreamData object. - - Args: - data_path (Path): The path to the packed data file. - load_index (bool, optional): Whether to load the index. Defaults to True. - - Raises: - FileNotFoundError: If the packed data file is not found at the specified path. - - """ - self._data_path = data_path - if not self._data_path.is_file(): - raise FileNotFoundError( - f"Packed Data was not found at {self._data_path.absolute()}." - f"Create on in advance by using `modalities data pack_encoded_data`." - ) - - with self._data_path.open("rb") as f: - # get number of bytes in data section - data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES) - self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little") - - # get number of bytes for encoding a single token - f.seek(self.DATA_SECTION_LENGTH_IN_BYTES) - token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES) - self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False) - - # get index - if load_index: - f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) - pkl_encoded_index = f.read() - # contains the start offset and length of each segment - # as byte positions in the data section - self._index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index) - else: - self._index_base = None - - # initialize memmapped data section - self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) - - @property - def index_base(self) -> list[tuple[int, int]]: - if self._index_base is None: - raise ValueError("Index was not loaded. Set `load_index=True` during initialization.") - return self._index_base - - @property - def data(self) -> np.ndarray: - return self._data - - -def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): - """ - Joins the embedded stream data into a single file. 
- - Args: - stream_data (list[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. - target_file (Path): The target file to write the joined data to. - chunk_size (int, optional): The size of each data chunk. Defaults to 2048. - - Raises: - FileExistsError: If the target file already exists. - - Returns: - None - """ - if target_file.exists(): - raise FileExistsError(f'Target File at "{target_file}" exists!') - data_len = sum(d.data_len for d in stream_data) - assert len({d.token_size_in_bytes for d in stream_data}) == 1, ( - "Found different token representation sizes. This could indicate the usage of different tokenizers. " - "Not supported!" - ) - token_size_in_bytes = stream_data[0].token_size_in_bytes - - num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data) - data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size)) - - num_entries = sum(len(d.index_base) for d in stream_data) - - def index_stream_generator() -> Iterator[tuple[int, int]]: - # generates a stream of index offsets and segment lengths. - curr_offset = 0 - for embedded_stream_data in stream_data: - for entry_offset, segment_length in embedded_stream_data.index_base: - yield entry_offset + curr_offset, segment_length - curr_offset += embedded_stream_data.data_len - curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES - - with target_file.open("wb") as fout: - fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) - fout.write( - token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little") - ) - for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."): - fout.write(data_chunk) - - joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")] - pickled_index = pickle.dumps(joint_index) - pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size)) - num_index_chunks = math.ceil(len(pickled_index) / chunk_size) - for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."): - fout.write(index_chunk) diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 11ba3d2dd..577d459fd 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -5,6 +5,8 @@ from typing import Optional import jq +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader import numpy as np from pydantic import BaseModel from torch.utils.data.dataset import Dataset as TorchdataSet @@ -13,8 +15,6 @@ from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper -from ..dataloader.large_file_lines_reader import LargeFileLinesReader -from .create_packed_data import EmbeddedStreamData class Dataset(TorchdataSet): @@ -163,7 +163,7 @@ def __init__( """ super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - self.reader = LargeFileLinesReader(self.raw_data_path, index_path=index_path) + self.reader = LocalLargeFileLinesReader(self.raw_data_path, index_path=index_path) self.jq_filter = jq.compile(jq_pattern) self.tokenizer = tokenizer diff --git a/src/modalities/dataloader/large_file_lines_reader.py 
b/src/modalities/dataloader/large_file_lines_reader.py deleted file mode 100644 index 6488cdd1b..000000000 --- a/src/modalities/dataloader/large_file_lines_reader.py +++ /dev/null @@ -1,130 +0,0 @@ -import mmap -import pickle -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Optional - - -class BaseReader(ABC): - @abstractmethod - def __len__(self) -> int: - raise NotImplementedError - - @abstractmethod - def __getitem__(self, key: int) -> str | list[str]: - raise NotImplementedError - - -class LargeFileLinesReader(BaseReader): - """LargeFileLinesReader class that read lines from a large file efficiently.""" - - def __init__( - self, - raw_data_path: Path, - index_path: Optional[Path] = None, - encoding: Optional[str] = "utf-8", - use_sample_length_from_index: bool = True, - ): - """ - Initializes a LargeFileLinesReader object. - - Args: - raw_data_path (Path): Path to a jsonl file, which holds text data. - index_path (Optional[Path]): Path to an index file, which indicates the start character/byte position - and length of samples given in `raw_data_path`. - If not defined, an index next to `raw_data_path` is picked, - by replacing its suffix with ".idx". - encoding (Optional[str]): The encoding of the file (default: "utf-8"). - If encoding is None, the raw data is read as bytes. - use_sample_length_from_index (bool): If True, the sample length is taken from the index file - i.e., the (offset, sample_length) pairs. If False, the sample length is calculated - as the difference between the starting point of the next and the current sample. - Returns: - None - """ - self.encoding = encoding - self.raw_data_path = raw_data_path - self.index_path = self.default_index_path(self.raw_data_path, index_path) - self.use_sample_length_from_index = use_sample_length_from_index - - if not self.raw_data_path.is_file(): - raise FileNotFoundError("Raw data file does not exist") - if not self.index_path.is_file(): - raise FileNotFoundError("Index file does not exist. Use `modalities data create_raw_index` to create one.") - - with self.index_path.open("rb") as f: - self.index = pickle.load(f) - - self.raw_data_fd = self.raw_data_path.open("rb") - self.mmapped_data_file = mmap.mmap(self.raw_data_fd.fileno(), 0, access=mmap.ACCESS_READ) - - def close(self): - self.mmapped_data_file.close() - self.raw_data_fd.close() - - @staticmethod - def default_index_path(raw_data_path: Path, index_path: Optional[Path] = None) -> Path: - """ - Returns the default index path for the given raw data path. - - Args: - raw_data_path (Path): The path to the raw data file. - index_path (Optional[Path]): The path to the index file (default: None). - - Returns: - Path: The default index path. - - Note: - If `index_path` is not provided, the default index path is generated by - appending the extension ".idx" to the stem of the `raw_data_path`. - """ - if index_path is None: - default_index_path = Path(raw_data_path.parent, f"{raw_data_path.stem}.idx") - print(f"No specific Index Path provided. Pointing to index next to input data at: {default_index_path}") - return default_index_path - return index_path - - def __len__(self) -> int: - """ - Returns the length of the index. - - Returns: - int: The length of the index. - """ - return len(self.index) - - def __getitem__(self, key: int) -> str | bytes: - """ - Retrieves an item from the LargeFileLinesReader. - - Args: - key (int): The index used to retrieve the item. - - Returns: - str | bytes: The item retrieved from the LargeFileLinesReader. 
- - Raises: - IndexError: If the key is out of range. - - """ - - offset, sample_length_in_bytes = self.index[key] - - # If use_sample_length_from_index = False, we calculate the sample length as the difference between the - # starting point of the next and the current sample. - # This allows for reading in the entire sample including the newline character. - if not self.use_sample_length_from_index: - if key + 1 < len(self.index): - sample_length_in_bytes = self.index[key + 1][0] - self.index[key][0] - else: - sample_length_in_bytes = len(self.mmapped_data_file) - offset - - return self._read_from_raw_file(offset, sample_length_in_bytes) - - def _read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str | bytes: - # Reads a specified number of bytes from a raw file starting from a given offset. - data = self.mmapped_data_file[offset : offset + sample_length_in_bytes] - if self.encoding is not None: - data_decoded = data.decode(self.encoding) - return data_decoded - return data diff --git a/src/modalities/dataloader/preprocessing/__init__.py b/src/modalities/dataloader/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/indexation/__init__.py b/src/modalities/dataloader/preprocessing/indexation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/indexation/global_indexation.py b/src/modalities/dataloader/preprocessing/indexation/global_indexation.py new file mode 100644 index 000000000..8add0f353 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/indexation/global_indexation.py @@ -0,0 +1,112 @@ +import pickle +from pathlib import Path + +import numpy as np +import tqdm + + +def _get_global_index_file_path(global_index_root_path: Path) -> Path: + global_index_file_path = global_index_root_path / f"{global_index_root_path.name}_inorder.idx" + return global_index_file_path + + +def _get_file_list(file_list_path: Path) -> list[Path]: + file_list: list[Path] = [] + with open(file_list_path, "r") as f: + for line in f: + file_list.append(Path(line.strip())) + return file_list + + +def _get_file_id_file_path_mappings(file_list: list[Path]) -> tuple[dict[Path, int], dict[int, Path]]: + file_path_to_id = {file_path.with_suffix(""): i for i, file_path in enumerate(file_list)} + id_to_file_path = {i: file_path.with_suffix("") for i, file_path in enumerate(file_list)} + return file_path_to_id, id_to_file_path + + +def _get_local_index_paths(file_list: list[Path], root_index_path: Path, global_index_root_path: Path) -> list[Path]: + local_index_paths = [ + path.with_suffix(".idx") + for path in file_list + if (root_index_path / path).is_relative_to(global_index_root_path) + ] + return local_index_paths + + +def _get_total_num_documents(local_index_paths: list[Path], root_index_path: Path) -> int: + num_documents = 0 + for local_index_path in tqdm.tqdm(local_index_paths, desc="Counting total number of documents"): + with open(root_index_path / local_index_path, "rb") as f: + index = pickle.load(f) + num_documents += len(index) + return num_documents + + +def _populate_global_index_array( + global_index_file_path: Path, + num_documents: int, + local_index_paths: list[Path], + root_index_path: Path, + file_path_to_id: dict[Path, int], +) -> np.memmap: + shape = (num_documents + 1, 3) + global_index_array = np.memmap(global_index_file_path, dtype="int64", mode="w+", shape=shape) + + # the first row is reserved for the shape of the array and 
whether rows are shuffled. + # + global_index_array[0] = np.array([*shape, 0]) + start_index = 1 + for local_index_path in tqdm.tqdm(local_index_paths, desc="Populating global index array"): + with open(root_index_path / local_index_path, "rb") as f: + local_index = pickle.load(f) + + local_index_array = np.array(local_index) + # add the file id to the local index + file_id = file_path_to_id[local_index_path.with_suffix("")] + local_index_array = np.insert(local_index_array, 0, file_id, axis=1) + + global_index_array[start_index : start_index + len(local_index_array)] = local_index_array + start_index += len(local_index_array) + global_index_array.flush() + return global_index_array + + +def create_global_index(file_list_path: Path, root_index_path: Path, global_index_root_path: Path) -> Path: + global_index_file_path = _get_global_index_file_path(global_index_root_path) + + file_list = _get_file_list(file_list_path) + + file_path_to_id, _ = _get_file_id_file_path_mappings(file_list) + local_index_paths = _get_local_index_paths(file_list, root_index_path, global_index_root_path) + num_documents = _get_total_num_documents(local_index_paths, root_index_path) + + _populate_global_index_array( + global_index_file_path, num_documents, local_index_paths, root_index_path, file_path_to_id + ) + return global_index_file_path + + +def create_shuffled_global_index(global_index_file_path: Path) -> Path: + global_shuffled_index_file_path = ( + global_index_file_path.parent / f"{global_index_file_path.stem.replace('inorder', 'shuffle_index')}.idx" + ) + print(global_shuffled_index_file_path) + + # global index array + num_rows, _, _ = np.memmap(global_index_file_path, dtype="int64", mode="r")[0:3] + + print(f"Shuffling {num_rows-1} global index indices") + # we count from 1 since the 0th row contains meta information (num_rows, num_cols, is_shuffled) + indices = np.arange(1, num_rows) + np.random.shuffle(indices) + + print(f"Writing out shuffled global index array with {num_rows} elements") + global_shuffled_index_array = np.memmap( + global_shuffled_index_file_path, dtype="int64", mode="w+", shape=(len(indices),) + ) + chunk_size = 10 + for i in tqdm.tqdm(range(0, len(indices), chunk_size)): + chunk_indices = indices[i : i + chunk_size] + global_shuffled_index_array[i : i + len(chunk_indices)] = chunk_indices + global_shuffled_index_array.flush() + return global_shuffled_index_file_path diff --git a/src/modalities/dataloader/create_index.py b/src/modalities/dataloader/preprocessing/indexation/local_indexation.py similarity index 61% rename from src/modalities/dataloader/create_index.py rename to src/modalities/dataloader/preprocessing/indexation/local_indexation.py index 55e573e15..d4b5469ca 100644 --- a/src/modalities/dataloader/create_index.py +++ b/src/modalities/dataloader/preprocessing/indexation/local_indexation.py @@ -1,16 +1,18 @@ -import json import os import pickle as pkl import queue import threading -import warnings +import time from pathlib import Path +import jq from tqdm import tqdm +from modalities.utils.logging import get_logger + class IndexGenerator: - def __init__(self, src_file: Path, drop_faulty_entries: bool = False): + def __init__(self, src_file: Path, drop_faulty_entries: bool = False, jq_pattern: str = ".text"): """ Initializes an IndexGenerator object. Reads a JSONL file as a binary file, and iterates through it character by character. 
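As a reading aid (not part of this change set): the new global_indexation module above writes two artifacts whose on-disk layout is only implicit in the code. The sketch below assumes only what _populate_global_index_array and create_shuffled_global_index show — the in-order global index is an int64 memmap whose row 0 stores (num_rows, num_cols, is_shuffled) and whose remaining rows store (file_id, byte_offset, byte_length), while the optional shuffle index is a flat int64 array of row ids into the in-order index. The helper name iter_global_index is purely illustrative.

from pathlib import Path

import numpy as np


def iter_global_index(global_index_path: Path, shuffle_index_path: Path | None = None):
    # row 0 is the meta-information row: (num_rows, num_cols, is_shuffled)
    num_rows, num_cols, _ = np.memmap(global_index_path, dtype="int64", mode="r")[0:3]
    index = np.memmap(global_index_path, dtype="int64", mode="r", shape=(int(num_rows), int(num_cols)))

    if shuffle_index_path is not None:
        # the shuffle index holds the row ids 1..num_rows-1 of the in-order index in shuffled order
        row_ids = np.memmap(shuffle_index_path, dtype="int64", mode="r")
    else:
        row_ids = range(1, int(num_rows))  # skip the meta-information row

    for row_id in row_ids:
        file_id, byte_offset, byte_length = index[int(row_id)]
        yield int(file_id), int(byte_offset), int(byte_length)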
@@ -26,6 +28,7 @@ def __init__(self, src_file: Path, drop_faulty_entries: bool = False): None """ self.src_file = src_file + self.jq_pattern = jq_pattern self.drop_faulty_entries = drop_faulty_entries with self.src_file.open(mode="rb") as fin: # Move the cursor to the end of the file @@ -50,6 +53,7 @@ def create_index(self, target_path_for_index_file: Path): Returns: None """ + start_time = time.time() self._exception_buffer = [] reader = threading.Thread(target=self._reader_thread) reader.start() @@ -58,9 +62,21 @@ def create_index(self, target_path_for_index_file: Path): reader.join() processor.join() if self._exception_buffer: + get_logger(name="main").warning( + f"Index creation failed for {target_path_for_index_file}. Exception buffer: {self._exception_buffer}" + ) raise self._exception_buffer[0] - print(f"Created index of length {len(self._index_map)}") - target_path_for_index_file.write_bytes(pkl.dumps(self._index_map)) + + if len(self._index_map) == 0: + get_logger(name="main").warning(f"Could not create index! No entries found in {self.src_file}") + else: + end_time = time.time() + get_logger(name="main").info( + f"Created index {target_path_for_index_file} of length {len(self._index_map)} " + f"at {len(self._index_map) / (end_time - start_time)} iterations/s." + ) + target_path_for_index_file.write_bytes(pkl.dumps(self._index_map)) + get_logger(name="main").info(f"Wrote index {target_path_for_index_file} to disc.") def _indexer_thread(self): # This method is responsible for indexing the lines in the queue and parsing them as JSON. @@ -78,33 +94,44 @@ def queue_generator(): break yield line - def parse_line_as_json(line_start_idx: int, line: str): + def parse_line_as_json(line_id: int, line_start_byte_pos: int, line: bytes, jq_filter): # Parses a line as JSON and appends the sample index, i.e., # the line start index and length to the index map. # If the line is faulty and `drop_faulty_entries` is set to True, a warning is issued. - try: # check if line is a valid json - json.loads(line) - self._index_map.append((line_start_idx, len(line))) - except Exception as low_level_err: + line_string = line.decode("utf-8") + jq_retrieved_text = jq_filter.input_text(line_string).first() + if jq_retrieved_text is not None: + if len(jq_retrieved_text) > 0: + self._index_map.append((line_start_byte_pos, len(line))) + else: + get_logger(name="main").warning( + f"Faulty line {line_id} (no text) in {str(self.src_file)}, skipping..." + ) + else: if self.drop_faulty_entries: - warnings.warn(f'faulty line "{line}", skipping...') + get_logger(name="main").warning( + f"Faulty line {line_id} (parsing error) in {str(self.src_file)}, skipping..." + ) else: - err = ValueError(f'faulty line "{line}", skipping...') - err.__cause__ = low_level_err + get_logger(name="main").warning(f"Faulty line {line_id} (parsing error), stopping...") + err = ValueError(f'Faulty line "{line} in {str(self.src_file)}') self._exception_buffer.append(err) + jq_filter = jq.compile(self.jq_pattern) self._index_map = [] - for line_start_idx, line in tqdm(queue_generator(), desc="Processed Lines"): + for line_id, line_start_byte_pos, line in tqdm(queue_generator(), desc="Processed Lines", disable=True): if self._check_for_parallel_errors(): return - parse_line_as_json(line_start_idx, line) + parse_line_as_json(line_id, line_start_byte_pos, line, jq_filter) def _reader_thread(self): # Reads lines from the source file and puts them into a queue. # This method is executed in a separate thread. 
It reads lines from the source file until # the end of the file is reached. Each line is put into a queue along with its cursor position. If any # errors are detected, the method returns immediately. + get_logger(name="main").info(f"Reading the jsonl file {self.src_file}...") + num_read_documents = 0 with open(self.src_file, "rb") as fin: while True: cursor = fin.tell() @@ -114,10 +141,13 @@ def _reader_thread(self): if fin.tell() == self._total_num_bytes: if line.endswith(b"\n"): line = line[:-1] - self._queue_of_raw_lines.put((cursor, line)) + self._queue_of_raw_lines.put((num_read_documents, cursor, line)) + num_read_documents += 1 break line_without_newline_char = line[:-1] - self._queue_of_raw_lines.put((cursor, line_without_newline_char)) + self._queue_of_raw_lines.put((num_read_documents, cursor, line_without_newline_char)) + num_read_documents += 1 + get_logger(name="main").info(f"Finished reading the jsonl file {self.src_file} (read {num_read_documents}).") self._queue_of_raw_lines.put(None) def _check_for_parallel_errors(self) -> bool: diff --git a/src/modalities/dataloader/preprocessing/queued_processing/__init__.py b/src/modalities/dataloader/preprocessing/queued_processing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py new file mode 100644 index 000000000..f5ff38a1a --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/process_controller.py @@ -0,0 +1,83 @@ +import multiprocessing as mp +from dataclasses import dataclass +from multiprocessing.synchronize import Event + +import tqdm + +from modalities.dataloader.preprocessing.queued_processing.processors import Processor +from modalities.utils.logging import get_logger + + +@dataclass +class PipelineStep: + name: str + poisonable: bool + input_queue: mp.Queue + processors: list[Processor] + + +class ProcessController: + def __init__(self, pipeline_steps: list[PipelineStep], stop_event: Event, join_timeout: int = 5): + """Initializes the ProcessController + Each pipeline step contains a list of processors that retrieve the data from the input queue, + process it and if necessary put it into the output queue of the next step. + """ + self._pipeline_steps = pipeline_steps + self._stop_event = stop_event + self._join_timeout = join_timeout + + def join_processors_in_step(self, step: PipelineStep): + """Joins the processors of a pipeline step + If the stop_event is set, the processors are terminated + """ + # poison the input queues of the processors + if step.poisonable: + for _ in tqdm.tqdm(step.processors, desc=f"Poisoning {step.name} processes"): + if step.input_queue is not None: + step.input_queue.put(None) + + # join the processors + num_exits = 0 + while num_exits < len(step.processors): + processor = step.processors[num_exits] + + # if the processor is not alive, we continue with the next one + if not processor.is_alive(): + get_logger().info(f"Processor {processor.full_name} is not alive. Continuing with the next processor.") + num_exits += 1 + continue + # if the stop event is set, we terminate the processor + if self._stop_event.is_set(): + try: + processor.terminate() + except Exception as e: + # if we can't terminate the processor, we continue with the next one + get_logger().error( + f"Error while terminating processor {processor.full_name}: {e}. " + "Continuing with the next processor." 
+ ) + num_exits += 1 + continue + get_logger().info(f"Terminated processor {processor.full_name}") + num_exits += 1 + # if the stop event is not set, we join the processor + else: + get_logger().info(f"Joining {processor.full_name} ...") + processor.join(timeout=self._join_timeout) + if processor.exitcode is None: + get_logger().info(f"Joining {processor.full_name} timed out. Exit code: {processor.exitcode} ...") + continue + get_logger().info(f"Joined processor {processor.full_name}. Exit code: {processor.exitcode}") + num_exits += 1 + + def run(self): + # start the processors + for step in self._pipeline_steps: + get_logger().info(f"Starting processors for step {step.name}") + for processor in step.processors: + processor.start() + + # wait for the processors to finish + for step in self._pipeline_steps: + get_logger().info(f"Stopping {step.name} processes...") + self.join_processors_in_step(step) diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py new file mode 100644 index 000000000..e2b29aec0 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/processing_strategy_if.py @@ -0,0 +1,16 @@ +from abc import ABC +from typing import Any, Optional + + +class ProcessingStrategyIF(ABC): + def process(self, item: Optional[Any] = None) -> dict[str, Any]: + raise NotImplementedError + + def __enter__(self): + raise NotImplementedError + + def finalize(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_value, traceback): + raise NotImplementedError diff --git a/src/modalities/dataloader/preprocessing/queued_processing/processors.py b/src/modalities/dataloader/preprocessing/queued_processing/processors.py new file mode 100644 index 000000000..39ed11b4a --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/processors.py @@ -0,0 +1,137 @@ +import multiprocessing as mp +import queue +import traceback +from multiprocessing.synchronize import Event +from typing import Any, Optional + +from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF +from modalities.exceptions import ProcessingStrategyDoneException, ProcessorException, ProcessorStopEventException +from modalities.utils.logging import get_logger + + +class QueueConsumer: + def __init__(self, in_q: mp.Queue, in_q_timeout: int): + self._in_q = in_q + self._in_q_timeout = in_q_timeout + self._consumed_items = 0 + + def get_item(self, stop_event: Event) -> Any: + while not stop_event.is_set(): + try: + item = self._in_q.get(timeout=self._in_q_timeout) + except queue.Empty: + continue + if item is None: + pass + self._consumed_items += 1 + return item + raise ProcessorStopEventException("Stop event was set") + + +class QueueProducer: + def __init__(self, out_q: mp.Queue, out_q_timeout: int): + self._out_q = out_q + self._out_q_timeout = out_q_timeout + + def put_item(self, item: Any, stop_event: Event): + while not stop_event.is_set(): + try: + self._out_q.put(item, timeout=self._out_q_timeout) + except queue.Full: + continue + return + raise ProcessorStopEventException("Stop event was set") + + +class Processor(mp.Process): + def __init__( + self, + out_qs: dict[str, mp.Queue], + in_q_timeout: int, + out_q_timeout: int, + strategy: ProcessingStrategyIF, + process_id: str, + process_type: str, + stop_event: Event, + set_stop_event_on_processing_error: bool, + in_q: mp.Queue = None, + 
logging_message_q_key: Optional[str] = None, + ): + super().__init__() + + self._consumer = QueueConsumer(in_q, in_q_timeout) if in_q is not None else None + self._producers: dict[str, QueueProducer] = { + q_key: QueueProducer(out_q, out_q_timeout) for q_key, out_q in out_qs.items() + } + self._strategy = strategy + self._stop_event = stop_event + self._process_type = process_type + self._process_id = process_id + self.exit_on_processing_error = set_stop_event_on_processing_error + self._logging_message_q_key = logging_message_q_key + # if the consumer is None, we are the first processor in the pipeline and we need to generate the items + self._processing_fun = self._generate_item if self._consumer is None else self._process_item + + @property + def process_id(self) -> str: + return self._process_id + + @property + def process_type(self) -> str: + return self._process_type + + @property + def full_name(self) -> str: + return f"{self._process_type}:{self._process_id}" + + def _generate_item(self): + try: + processed_sub_items: dict[str, Any] = self._strategy.process() + except ProcessingStrategyDoneException as e: + self._strategy.finalize() + get_logger().info(f"{self.full_name} received done (iterator exhausted). Exiting...") + raise e + self._forward_sub_items(processed_sub_items) + + def _process_item(self): + item = self._consumer.get_item(stop_event=self._stop_event) + if item is None: + self._strategy.finalize() + raise ProcessingStrategyDoneException(f"{self.full_name} received done (poison pill).") + # process the item + try: + processed_sub_items: dict[str, Any] | None = self._strategy.process(item) + except Exception as e: + get_logger().error(f"{self.full_name} failed to process item {item}. Error: {e}") + if self.exit_on_processing_error: + raise ProcessorException(f"{self.full_name} failed to process item {item}.") from e + return # continue with the next item + # forward the processed sub items to the respective queues + self._forward_sub_items(processed_sub_items) + + def _forward_sub_items(self, processed_sub_items: dict[str, Any]): + # place the processed sub items in the correct out queues + for destination_q_key, processed_sub_item in processed_sub_items.items(): + if destination_q_key == self._logging_message_q_key: + processed_sub_item.process_id = self._process_id + processed_sub_item.process_type = self._process_type + self._producers[destination_q_key].put_item(processed_sub_item, stop_event=self._stop_event) + + def run(self): + try: + with self._strategy: + while True: + self._processing_fun() + except ProcessingStrategyDoneException: + pass + except ProcessorStopEventException: + # if the stop event was set, some process in the pipeline failed and we need to exit + get_logger().info(f"{self.full_name} received forced stop event. 
Exiting...") + except Exception as e: + # in this block, every exception comes from this very process and we need to set the stop event + # to signal the other processes of the pipeline that something went wrong + stacktrace = traceback.format_exc() + get_logger().error(f"Stacktrace for {self.full_name} : {stacktrace}") + get_logger().error(f"{self.full_name} failed with error: {e}, setting stop event") + self._stop_event.set() + get_logger().error(f"{self.full_name} exiting...") diff --git a/src/modalities/dataloader/preprocessing/queued_processing/queue_items.py b/src/modalities/dataloader/preprocessing/queued_processing/queue_items.py new file mode 100644 index 000000000..3f99d39a4 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/queued_processing/queue_items.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +@dataclass +class ReadingJob: + sample_id: int + batch_size: int + + +@dataclass +class ProgressMessage: + worker_type: Enum + num_samples: int + process_type: Optional[str] = None + process_id: Optional[str] = None diff --git a/src/modalities/dataloader/preprocessing/tokenization/__init__.py b/src/modalities/dataloader/preprocessing/tokenization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py b/src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py new file mode 100644 index 000000000..c7bb74d12 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/embedded_stream_data.py @@ -0,0 +1,123 @@ +import math +import pickle +from pathlib import Path +from typing import Iterator, Optional + + +import numpy as np +from tqdm import tqdm + + +class EmbeddedStreamData: + # amount of bytes to represent number of all tokens in dataset. + # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. + # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides + DATA_SECTION_LENGTH_IN_BYTES = 8 + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4 + HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES + + def __init__(self, data_path: Path, load_index: Optional[bool] = True): + """ + Initializes an EmbeddedStreamData object. + + Args: + data_path (Path): The path to the packed data file. + load_index (bool, optional): Whether to load the index. Defaults to True. + + Raises: + FileNotFoundError: If the packed data file is not found at the specified path. + + """ + self._data_path = data_path + if not self._data_path.is_file(): + raise FileNotFoundError( + f"Packed Data was not found at {self._data_path.absolute()}." + f"Create on in advance by using `modalities data pack_encoded_data`." 
+ ) + + with self._data_path.open("rb") as f: + # get number of bytes in data section + data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES) + self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="little") + + # get number of bytes for encoding a single token + f.seek(self.DATA_SECTION_LENGTH_IN_BYTES) + token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES) + self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False) + + # get index + if load_index: + f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) + pkl_encoded_index = f.read() + # contains the start offset and length of each segment + # as byte positions in the data section + self._index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index) + else: + self._index_base = None + + # initialize memmapped data section + self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) + + @property + def index_base(self) -> list[tuple[int, int]]: + if self._index_base is None: + raise ValueError("Index was not loaded. Set `load_index=True` during initialization.") + return self._index_base + + @property + def data(self) -> np.ndarray: + return self._data + + +def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): + """ + Joins the embedded stream data into a single file. + + Args: + stream_data (list[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. + target_file (Path): The target file to write the joined data to. + chunk_size (int, optional): The size of each data chunk. Defaults to 2048. + + Raises: + FileExistsError: If the target file already exists. + + Returns: + None + """ + if target_file.exists(): + raise FileExistsError(f'Target File at "{target_file}" exists!') + data_len = sum(d.data_len for d in stream_data) + assert len({d.token_size_in_bytes for d in stream_data}) == 1, ( + "Found different token representation sizes. This could indicate the usage of different tokenizers. " + "Not supported!" + ) + token_size_in_bytes = stream_data[0].token_size_in_bytes + + num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data) + data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size)) + + num_entries = sum(len(d.index_base) for d in stream_data) + + def index_stream_generator() -> Iterator[tuple[int, int]]: + # generates a stream of index offsets and segment lengths. 
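The on-disk layout handled here is small enough to check by hand: an 8-byte little-endian length of the data section, a 4-byte token-size descriptor, the raw token bytes, and finally a pickled list of (offset, length) pairs. A minimal inspection sketch, assuming a hypothetical .pbin path (this snippet is not part of the patch):

    import pickle
    from pathlib import Path

    HEADER_SIZE = 8 + 4  # data-section length descriptor + token-size descriptor

    def inspect_pbin(pbin_path: Path) -> None:
        # Parse the header, skip over the data section and unpickle the index.
        with pbin_path.open("rb") as f:
            data_len = int.from_bytes(f.read(8), byteorder="little")
            token_size_in_bytes = int.from_bytes(f.read(4), byteorder="little", signed=False)
            f.seek(HEADER_SIZE + data_len)
            index: list[tuple[int, int]] = pickle.loads(f.read())  # (offset, length) per document
        num_tokens = data_len // token_size_in_bytes
        print(f"{len(index)} documents, {num_tokens} tokens at {token_size_in_bytes} byte(s) per token")

    inspect_pbin(Path("data/example.pbin"))  # placeholder path

This mirrors what the constructor above does when load_index=True, minus the memmap over the data section.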
+ curr_offset = 0 + for embedded_stream_data in stream_data: + for entry_offset, segment_length in embedded_stream_data.index_base: + yield entry_offset + curr_offset, segment_length + curr_offset += embedded_stream_data.data_len + curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES + + with target_file.open("wb") as fout: + fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) + fout.write( + token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little") + ) + for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."): + fout.write(data_chunk) + + joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")] + pickled_index = pickle.dumps(joint_index) + pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size)) + num_index_chunks = math.ceil(len(pickled_index) / chunk_size) + for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."): + fout.write(index_chunk) diff --git a/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py new file mode 100644 index 000000000..1162a4cf1 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/large_file_lines_reader.py @@ -0,0 +1,299 @@ +import mmap +import pickle +from abc import ABC, abstractmethod +from enum import Enum +from pathlib import Path +from typing import Optional + +import numpy as np +from pydantic import BaseModel + +from modalities.dataloader.preprocessing.tokenization.queue_items import Sample +from modalities.exceptions import ReaderIndexationError + + +class BaseReaderIF(ABC): + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def __getitem__(self, key: int) -> Sample: + raise NotImplementedError + + +class LocalLargeFileLinesReader(BaseReaderIF): + """LargeFileLinesReader class that read lines from a large file efficiently.""" + + def __init__( + self, + raw_data_path: Path, + index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + use_sample_length_from_index: bool = True, + ): + """ + Initializes a LargeFileLinesReader object. + + Args: + raw_data_path (Path): Path to a jsonl file, which holds text data. + index_path (Optional[Path]): Path to an index file, which indicates the start character/byte position + and length of samples given in `raw_data_path`. + If not defined, an index next to `raw_data_path` is picked, + by replacing its suffix with ".idx". + encoding (Optional[str]): The encoding of the file (default: "utf-8"). + If encoding is None, the raw data is read as bytes. + use_sample_length_from_index (bool): If True, the sample length is taken from the index file + i.e., the (offset, sample_length) pairs. If False, the sample length is calculated + as the difference between the starting point of the next and the current sample. + Returns: + None + """ + self.encoding = encoding + self.raw_data_path = raw_data_path + self.index_path = self.default_index_path(self.raw_data_path, index_path) + self.use_sample_length_from_index = use_sample_length_from_index + + if not self.raw_data_path.is_file(): + raise FileNotFoundError("Raw data file does not exist") + if not self.index_path.is_file(): + raise FileNotFoundError("Index file does not exist. 
Use `modalities data create_raw_index` to create one.") + + with self.index_path.open("rb") as f: + self.index = pickle.load(f) + + self.raw_data_fd = self.raw_data_path.open("rb") + self.mmapped_data_file = mmap.mmap(self.raw_data_fd.fileno(), 0, access=mmap.ACCESS_READ) + + def close(self): + self.mmapped_data_file.close() + self.raw_data_fd.close() + + @staticmethod + def default_index_path(raw_data_path: Path, index_path: Optional[Path] = None) -> Path: + """ + Returns the default index path for the given raw data path. + + Args: + raw_data_path (Path): The path to the raw data file. + index_path (Optional[Path]): The path to the index file (default: None). + + Returns: + Path: The default index path. + + Note: + If `index_path` is not provided, the default index path is generated by + appending the extension ".idx" to the stem of the `raw_data_path`. + """ + if index_path is None: + default_index_path = Path(raw_data_path.parent, f"{raw_data_path.stem}.idx") + print(f"No specific Index Path provided. Pointing to index next to input data at: {default_index_path}") + return default_index_path + return index_path + + def __len__(self) -> int: + """ + Returns the length of the index. + + Returns: + int: The length of the index. + """ + return len(self.index) + + def __getitem__(self, key: int) -> Sample: + """ + Retrieves an item from the LargeFileLinesReader. + + Args: + key (int): The index used to retrieve the item. + + Returns: + Sample: The item retrieved from the LargeFileLinesReader. + + Raises: + IndexError: If the key is out of range. + + """ + + offset, sample_length_in_bytes = self.index[key] + + # If use_sample_length_from_index = False, we calculate the sample length as the difference between the + # starting point of the next and the current sample. + # This allows for reading in the entire sample including the newline character. + if not self.use_sample_length_from_index: + if key + 1 < len(self.index): + sample_length_in_bytes = self.index[key + 1][0] - self.index[key][0] + else: + sample_length_in_bytes = len(self.mmapped_data_file) - offset + + content = self._read_from_raw_file(offset, sample_length_in_bytes) + return Sample( + raw_data_path=self.raw_data_path, + incremental_line_id=key, + shuffled_line_id=key, # TODO so far we don't support shuffling here! + offset=offset, + sample_length_in_bytes=sample_length_in_bytes, + content_raw=content, + ) + + def _read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str | bytes: + # Reads a specified number of bytes from a raw file starting from a given offset. 
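For context, a short usage sketch of the local reader defined above; both paths are placeholders and the .idx file is assumed to have been created beforehand with the indexing CLI:

    from pathlib import Path

    from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader

    # Placeholder paths; the index next to the jsonl file must already exist.
    raw_data_path = Path("data/corpus.jsonl")
    index_path = Path("data/corpus.idx")

    reader = LocalLargeFileLinesReader(
        raw_data_path=raw_data_path,
        index_path=index_path,
        encoding="utf-8",
        use_sample_length_from_index=True,
    )
    try:
        print(f"{len(reader)} indexed documents")
        sample = reader[0]  # a Sample model carrying offset, length and the raw line content
        print(sample.offset, sample.sample_length_in_bytes)
        print(sample.content_raw[:80])
    finally:
        reader.close()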
+ data = self.mmapped_data_file[offset : offset + sample_length_in_bytes] + if self.encoding is not None: + data_decoded = data.decode(self.encoding) + return data_decoded + return data + + +class GlobalLargeFileLinesReader(BaseReaderIF): + """LargeFileLinesReader class that read lines from a large file efficiently.""" + + def __init__( + self, + global_inorder_index_path: Path, + raw_data_file_list_path: Path, + raw_data_root_path: Path, + global_shuffle_index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + ): + self.global_inorder_index_path = global_inorder_index_path + self.raw_data_file_list_path = raw_data_file_list_path + self.raw_data_root_path = raw_data_root_path + self.global_shuffle_index_path = global_shuffle_index_path + self.encoding = encoding + + # create the raw data file path list (the JSONL files) + # the file paths are relative to the raw_data_root_path + with open(self.raw_data_file_list_path, "r", encoding="utf-8") as f: + self.relative_raw_data_file_paths = [line.strip() for line in f.readlines()] + + self.relative_to_absolute_raw_data_file_paths = { + rel_file_path: raw_data_root_path / rel_file_path for rel_file_path in self.relative_raw_data_file_paths + } + + # open memmap / index files + num_rows, _, _ = np.memmap(self.global_inorder_index_path, dtype="int64", mode="r")[0:3] + + self.global_index_inorder = np.memmap( + self.global_inorder_index_path, dtype="int64", mode="r", shape=(num_rows, 3) + ) + if self.global_shuffle_index_path is not None: + self.global_shuffle_index = np.memmap(self.global_shuffle_index_path, dtype="int64", mode="r") + else: + self.global_shuffle_index = None + # the 0th element in the global_index_inorder contains the meta data (num_rows, num_cols, is_shuffled) + # therefore we have to skip the first element when iterating. + # Note, when we iterate via the global_shuffle_index, we don't have to do this, + # as the the shuffled index does not contain index 0. + self.global_index_inorder = self.global_index_inorder[1:] + + def close(self): + pass + + def __len__(self) -> int: + """ + Returns the length of the index. + + Returns: + int: The length of the index. + """ + if self.global_shuffle_index is not None: + return len(self.global_shuffle_index) + else: + return len(self.global_index_inorder) + + def __getitem__(self, key: int) -> Sample: + """ + Retrieves an item from the LargeFileLinesReader. + + Args: + key (int): The index used to retrieve the item. + + Returns: + Sample: The item retrieved from the LargeFileLinesReader. + + Raises: + IndexError: If the key is out of range. 
+ + """ + try: + if self.global_shuffle_index is not None: + mapped_key = self.global_shuffle_index[key] + else: + mapped_key = key + file_index, offset, sample_length_in_bytes = self.global_index_inorder[mapped_key] + rel_file_path = self.relative_raw_data_file_paths[file_index] + abs_raw_file_path = self.relative_to_absolute_raw_data_file_paths[rel_file_path] + except Exception as e: + raise ReaderIndexationError(f"Error while reading sample with key {key}: {e}") from e + + with open(abs_raw_file_path, "rb") as fd: + raw_data_mmap = mmap.mmap(fd.fileno(), 0, access=mmap.ACCESS_READ) + content = bytes(raw_data_mmap[offset : offset + sample_length_in_bytes]) + raw_data_mmap.close() # Explicitly close mmap + + if self.encoding is not None: + content = content.decode(self.encoding) + return Sample( + incremental_line_id=key, + shuffled_line_id=mapped_key, + raw_data_path=abs_raw_file_path, + offset=offset, + sample_length_in_bytes=sample_length_in_bytes, + content_raw=content, + ) + + +class LargeFileLinesReaderTypes(Enum): + LOCAL = "LOCAL" + GLOBAL = "GLOBAL" + + +class IndexTypes(Enum): + LOCAL = "LOCAL" + GLOBAL = "GLOBAL" + + +class LocalLargeFileLinesReaderConfig(BaseModel): + raw_data_path: Path + index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + +class GlobalLargeFileLinesReaderConfig(BaseModel): + global_inorder_index_path: Path + raw_data_file_list_path: Path + raw_data_root_path: Path + global_shuffle_index_path: Optional[Path] = None + encoding: Optional[str] = "utf-8" + + +class LargeFileLinesReaderFactory: + @staticmethod + def get_local_reader( + raw_data_path: Path, + index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + ) -> LocalLargeFileLinesReader: + return LocalLargeFileLinesReader( + raw_data_path=raw_data_path, + index_path=index_path, + encoding=encoding, + use_sample_length_from_index=True, + ) + + @staticmethod + def get_global_reader( + global_inorder_index_path: Path, + raw_data_file_list_path: Path, + raw_data_root_path: Path, + global_shuffle_index_path: Optional[Path] = None, + encoding: Optional[str] = "utf-8", + ) -> GlobalLargeFileLinesReader: + return GlobalLargeFileLinesReader( + global_inorder_index_path=global_inorder_index_path, + raw_data_file_list_path=raw_data_file_list_path, + raw_data_root_path=raw_data_root_path, + global_shuffle_index_path=global_shuffle_index_path, + encoding=encoding, + ) diff --git a/src/modalities/dataloader/preprocessing/tokenization/queue_items.py b/src/modalities/dataloader/preprocessing/tokenization/queue_items.py new file mode 100644 index 000000000..a9754fce3 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/queue_items.py @@ -0,0 +1,20 @@ +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel + + +class Sample(BaseModel): + # If the index is not shuffled, then the incrementeal_line_id + # points to the position in the dataset + # If the index is shuffled, then the incremental_line_id + # points to the position in the shuffled index and the + # shuffled_line_id points to the position in the original index + incremental_line_id: int + raw_data_path: Path + offset: int + sample_length_in_bytes: int + content_raw: str | bytes + content_tokenized: Optional[bytes] = None + token_size_in_bytes: Optional[int] = None + shuffled_line_id: Optional[int] = None diff --git a/src/modalities/dataloader/preprocessing/tokenization/strategies.py b/src/modalities/dataloader/preprocessing/tokenization/strategies.py new file mode 100644 
index 000000000..86e01c239 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/strategies.py @@ -0,0 +1,464 @@ +import math +import multiprocessing as mp +import os +import pickle +import time +from io import BufferedWriter +from pathlib import Path +from typing import Type + +import jq +from pydantic import BaseModel + +from modalities.config.instantiation_models import TokenizationInstantiationModel +from modalities.dataloader.preprocessing.queued_processing.processing_strategy_if import ProcessingStrategyIF +from modalities.dataloader.preprocessing.queued_processing.queue_items import ProgressMessage, ReadingJob +from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import EmbeddedStreamData +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( + BaseReaderIF, + LargeFileLinesReaderFactory, + LargeFileLinesReaderTypes, +) +from modalities.dataloader.preprocessing.tokenization.queue_items import Sample +from modalities.dataloader.preprocessing.tokenization.worker_types import WorkerTypes +from modalities.exceptions import EmptySampleError, ProcessingStrategyDoneException +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper +from modalities.utils.logging import get_logger + + +def get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: + """ + Calculates the required number of bytes to represent an integer. + + Args: + int_to_get_repr (int): The integer to get the representation for. + + Returns: + int: The number of bytes required to represent the integer. + """ + # we currently only support token sizes of 1, 2 and 4 bytes, as implemented here: + # https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202 + num_bytes = math.ceil(math.log2(int_to_get_repr) / 8) + if num_bytes == 1: + return 1 + elif num_bytes == 2: + return 2 + elif num_bytes <= 4: + return 4 + else: + raise ValueError("Currently only support token byte sizes of 1, 2, and 4.") + + +class PopulatingStrategy(ProcessingStrategyIF): + def __init__( + self, reader_q_key: str, logging_message_q_key: str, index_start: int, num_samples: int, batch_size: int + ): + self._reader_q_key = reader_q_key + self._logging_message_q_key = logging_message_q_key + self._batch_size = batch_size + self._reading_iter = iter(range(index_start, index_start + num_samples, batch_size)) + + def __enter__(self): + return self + + def finalize(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def process(self) -> dict[str, ReadingJob | ProgressMessage]: + try: + sample_id = next(self._reading_iter) + except StopIteration as e: + raise ProcessingStrategyDoneException("PopulatingStrategy done.") from e + reading_job = ReadingJob(sample_id=sample_id, batch_size=self._batch_size) + progress_message = ProgressMessage(WorkerTypes.POPULATOR, num_samples=self._batch_size) + return {self._reader_q_key: reading_job, self._logging_message_q_key: progress_message} + + +class ReadingStrategy(ProcessingStrategyIF): + def __init__( + self, reader_type: Type[BaseReaderIF], reader_args: BaseModel, tokenizer_q_key: str, logging_message_q_key: str + ): + self._reader_type = reader_type + self._reader_args = reader_args + self._reader = None + self._tokenizer_q_key = tokenizer_q_key + self._logging_message_q_key = logging_message_q_key + + def __enter__(self): + self._reader = 
self._reader_type(**self._reader_args.model_dump()) + return self + + def finalize(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self._reader.close() + + def process(self, item: ReadingJob) -> dict[str, list[Sample] | ProgressMessage]: + batch: list[Sample] = [self._reader[item.sample_id + i] for i in range(item.batch_size)] + progress_message = ProgressMessage(WorkerTypes.READER, len(batch)) + return {self._tokenizer_q_key: batch, self._logging_message_q_key: progress_message} + + +class TokenizingStrategy(ProcessingStrategyIF): + def __init__( + self, + ti_settings: ( + TokenizationInstantiationModel.TokenizerWorkerSettings.TokenizerSettings.TokenizerInstantitionSettings + ), + eod_token: str, + jq_pattern: str, + writer_q_key: str, + logging_message_q_key: str, + ): + self._tokenizer_instantiation_setings = ti_settings + self._eod_token = eod_token + self._jq_filter = jq.compile(jq_pattern) + self._writer_q_key = writer_q_key + self._logging_message_q_key = logging_message_q_key + self._tokenizer = None + self._token_size_in_bytes = None + self._encoded_eod_token_as_bytes = None + + def __enter__(self): + registry = Registry(COMPONENTS) + tokenizer_type: Type[TokenizerWrapper] = registry.get_component( + component_key=self._tokenizer_instantiation_setings.tokenizer_component_key, + variant_key=self._tokenizer_instantiation_setings.tokenizer_variant_key, + ) + self._tokenizer: TokenizerWrapper = tokenizer_type(**self._tokenizer_instantiation_setings.config) + + self._token_size_in_bytes = get_required_num_of_bytes_to_repr(self._tokenizer.vocab_size) + eod_token_id = self._tokenizer.get_token_id(self._eod_token) + self._encoded_eod_token_as_bytes = TokenizingStrategy._encoded_token_to_bytes( + token_size_in_bytes=self._token_size_in_bytes, encoded_token=eod_token_id + ) + return self + + def finalize(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def process(self, item: list[Sample]) -> dict[str, list[Sample] | ProgressMessage]: + batch_processed = [] + for sample in item: + processed_line = self._process_line(sample.content_raw) + sample.content_tokenized = processed_line + sample.token_size_in_bytes = self._token_size_in_bytes + batch_processed.append(sample) + progress_message = ProgressMessage(WorkerTypes.TOKENIZER, num_samples=len(batch_processed)) + return {self._writer_q_key: batch_processed, self._logging_message_q_key: progress_message} + + def _process_line(self, line: str) -> bytes: + # extracts the text via the jq_filter and applies tokenization to the extract text + jq_retrieved_text = self._jq_filter.input_text(line).first() + if jq_retrieved_text is None: + raise ValueError(f"jq was not able extract the text using the expression: {self._jq_filter}") + tokens = self._tokenizer.tokenize(jq_retrieved_text) + if len(tokens) == 0: + raise EmptySampleError("Received empty sample...") + + token_byte_string = b"".join( + map(self._encoded_token_to_bytes, [self._token_size_in_bytes] * len(tokens), tokens) + ) + if not token_byte_string.endswith(self._encoded_eod_token_as_bytes): + token_byte_string = token_byte_string + self._encoded_eod_token_as_bytes + return token_byte_string + + @staticmethod + def _encoded_token_to_bytes(token_size_in_bytes: int, encoded_token: int) -> bytes: + # Converts an encoded token to its bytes representaion. 
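To make the byte-width rule used by the tokenizing strategy concrete: the per-token width is derived from the tokenizer's vocab size, restricted to 1, 2 or 4 bytes, and each token id is serialized as fixed-width little-endian unsigned bytes. A self-contained illustration (the vocab size is only an example):

    import math

    def bytes_per_token(vocab_size: int) -> int:
        # Same rule as get_required_num_of_bytes_to_repr: round up to 1, 2 or 4 bytes.
        num_bytes = math.ceil(math.log2(vocab_size) / 8)
        if num_bytes == 1:
            return 1
        if num_bytes == 2:
            return 2
        if num_bytes <= 4:
            return 4
        raise ValueError("Currently only support token byte sizes of 1, 2, and 4.")

    def token_to_bytes(token_id: int, width: int) -> bytes:
        # Mirrors _encoded_token_to_bytes: fixed-width little-endian, unsigned.
        return token_id.to_bytes(width, byteorder="little", signed=False)

    assert bytes_per_token(50_257) == 2            # e.g. a GPT-2-sized vocabulary
    assert token_to_bytes(50_256, 2) == b"\x50\xc4"
    assert int.from_bytes(b"\x50\xc4", "little") == 50_256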
+ return encoded_token.to_bytes(token_size_in_bytes, byteorder="little", signed=False) + + +class WritingStrategy(ProcessingStrategyIF): + def __init__(self, dst_path: Path, index_start: int, logging_message_q_key: str): + self._dst_path = dst_path + self._index_start = index_start + self._logging_message_q_key = logging_message_q_key + self._dst_fd = None + self._finalized = None + self._curr_offset = None + self._prev_line_id = None + self._batch_dict = None + self._index_list = None + self._has_seen_first_batch = None + + if not self._dst_path.parent.exists(): + self._dst_path.parent.mkdir(parents=True, exist_ok=True) + + def __enter__(self): + self._dst_fd = self._dst_path.open("wb") + self._finalized = False + # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) + # not possible to prepend header after determining size of data section + self._dst_fd.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little")) + + # The offset only applies to the data section, not the header + # When we load the file, we add the header size to the offset + self._curr_offset = 0 + + self._prev_line_id = self._index_start - 1 + self._batch_dict = {} + self._index_list = [] + self._has_seen_first_batch = False + + return self + + def finalize(self): + # check that the index list IS NOT empty and the batch_dict IS empty + # i.e., all batches have been written to the file + if len(self._index_list) == 0 or len(self._batch_dict) > 0: + raise ValueError( + f"Could not finalize writing strategy. Index list is empty or batch_dict is not empty. " + f"Index list: {len(self._index_list)}, batch_dict: {self._batch_dict.keys()}" + ) + else: + # write index + self._dst_fd.write(pickle.dumps(self._index_list)) + self._dst_fd.close() + self._update_data_length_in_pre_allocated_header(self._dst_path, self._index_list) + self._finalized = True + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._finalized: + self._dst_fd.close() + # if the process was stopped due to a stop event or the index list is empty, we remove the file + get_logger(name="main").warning( + f"Removing file {self._dst_path} due to non-finalized pbin file. The pbin file either is not " + "finalized as WritingStrategy.finalize() was not called or not all samples have been written " + f"to disc. 
index_list: {len(self._index_list)}, batch_dict: {self._batch_dict.keys()}" + ) + os.remove(self._dst_path) + + def process(self, item: list[Sample]) -> dict[str, ProgressMessage]: + if not self._has_seen_first_batch: + # write the token size descriptor to the file + # we receive this information from the tokenizer (based on the tokenizer's vocab size) + # and is always provided within the Sample object + self._has_seen_first_batch = True + self._dst_fd.write( + item[0].token_size_in_bytes.to_bytes( + EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="little" + ) + ) + + line_id = item[0].incremental_line_id + self._batch_dict[line_id] = item + + num_samples_written = 0 + while self._prev_line_id + 1 in self._batch_dict: + batch = self._batch_dict.pop(self._prev_line_id + 1) + self._prev_line_id, self._curr_offset = WritingStrategy._write_batch( + batch, self._prev_line_id, self._curr_offset, self._index_list, self._dst_fd + ) + num_samples_written += len(batch) + progress_message = ProgressMessage(WorkerTypes.WRITER, num_samples=num_samples_written) + return {self._logging_message_q_key: progress_message} + + # writes a batch received from the writer_q to the destination file + @staticmethod + def _write_batch( + batch: list[Sample], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter + ) -> tuple[int, int]: + # write the tokens for each document + for sample in batch: + if prev_line_id + 1 != sample.incremental_line_id: + raise ValueError( + f"Line IDs are not consecutive. Expected {prev_line_id + 1}, but got {sample.incremental_line_id}" + ) + f.write(sample.content_tokenized) + segment_length = len(sample.content_tokenized) + index_list.append((curr_offset, segment_length)) + curr_offset += segment_length + prev_line_id = sample.incremental_line_id + return prev_line_id, curr_offset + + @staticmethod + def _update_data_length_in_pre_allocated_header(dst_path: Path, index_list: list[tuple[int, int]]): + # Update the length of the data section in the pre-allocated header of the destination file. + # The data segment length is sum of the starting position and the length of the last document. 
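The write path has to cope with batches arriving out of order from several tokenizer workers: each batch is parked under the line id of its first sample and flushed only once the next expected id is present, so documents land on disk in their original order. A stripped-down, file-free sketch of that reordering (names are illustrative):

    class BatchReorderer:
        """Toy stand-in for the reordering done in WritingStrategy.process (no file I/O)."""

        def __init__(self, index_start: int = 0):
            self._pending: dict[int, list[str]] = {}  # first line id of a parked batch -> batch
            self._prev_line_id = index_start - 1

        def on_batch(self, first_line_id: int, batch: list[str]) -> list[str]:
            self._pending[first_line_id] = batch
            flushed: list[str] = []
            # Flush as long as the next expected batch is available, keeping line ids consecutive.
            while self._prev_line_id + 1 in self._pending:
                ready = self._pending.pop(self._prev_line_id + 1)
                flushed.extend(ready)  # stands in for appending token bytes to the pbin file
                self._prev_line_id += len(ready)
            return flushed

    reorderer = BatchReorderer()
    print(reorderer.on_batch(2, ["c", "d"]))  # [] -> parked until the batch starting at line 0 arrives
    print(reorderer.on_batch(0, ["a", "b"]))  # ['a', 'b', 'c', 'd']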
+ length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] + data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( + EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="little" + ) + with dst_path.open("rb+") as fout: + fout.seek(0) + fout.write(data_section_length_in_bytes) + + +class ProgressLoggingStrategy(ProcessingStrategyIF): + def __init__( + self, + logging_interval: int, + total_num_samples: int, + q_dict: dict[str, mp.Queue], + ): + self._logging_interval = logging_interval + self._total_num_samples = total_num_samples + self._worker_to_pid_to_num_samples: dict[WorkerTypes, dict[int, int]] = {} + self._worker_type_to_processed_num_samples = {worker_type: 0 for worker_type in WorkerTypes} + self._q_dict = q_dict + self._last_logged = None + + def __enter__(self): + self._last_logged = time.time() + + def finalize(self): + passed_time = time.time() - self._last_logged + self._log_and_reset(passed_time) + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def process(self, item: ProgressMessage) -> dict: + self._add_progress_message(item) + passed_time = time.time() - self._last_logged + if passed_time > self._logging_interval: + self._log_and_reset(passed_time) + self._last_logged = time.time() + return {} + + def _add_progress_message(self, progress_message: ProgressMessage): + if progress_message.worker_type not in self._worker_to_pid_to_num_samples: + self._worker_to_pid_to_num_samples[progress_message.worker_type] = {} + + if progress_message.process_id not in self._worker_to_pid_to_num_samples[progress_message.worker_type]: + self._worker_to_pid_to_num_samples[progress_message.worker_type][progress_message.process_id] = 0 + + self._worker_to_pid_to_num_samples[progress_message.worker_type][ + progress_message.process_id + ] += progress_message.num_samples + self._worker_type_to_processed_num_samples[progress_message.worker_type] += progress_message.num_samples + + def _log_and_reset(self, passed_time: int): + logging_message = f"\n==================Progress report (last {passed_time}s) ==================\n" + + logging_message += "Total progress: \n" + for worker_type, processed_num_samples in self._worker_type_to_processed_num_samples.items(): + m = ( + f"\t{worker_type.name}: {processed_num_samples}/{self._total_num_samples} samples " + f"({processed_num_samples/self._total_num_samples*100}%)\n" + ) + logging_message += m + + logging_message += "\n" + logging_message += "Aggregated Throughput: \n" + + for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): + total_samples = sum(pid_to_num_samples.values()) + logging_message += f"\t{worker_type.name} workers: {total_samples/passed_time} samples/s.\n" + logging_message += "\n" + logging_message += "Worker Throughput: \n" + for worker_type, pid_to_num_samples in self._worker_to_pid_to_num_samples.items(): + logging_message += f"{worker_type.name} workers:\n" + for pid, num_samples in pid_to_num_samples.items(): + logging_message += f"\t{worker_type.name} {pid}: {num_samples/passed_time} samples/s.\n" + logging_message += "\n" + logging_message += "\n" + + logging_message += "Queues: \n" + for q_key, q in self._q_dict.items(): + logging_message += f"\t{q_key}: {q.qsize()} batches (approx.)\n" + + get_logger().info("%s", logging_message) + + # reset values + for worker_type in self._worker_to_pid_to_num_samples.keys(): + self._worker_to_pid_to_num_samples[worker_type] = { + pid: 0 for pid in self._worker_to_pid_to_num_samples[worker_type].keys() 
+ } + + +class ProcessingStrategyFactory: + @staticmethod + def get_populating_strategy( + reader_q_key: str, logging_message_q_key: str, index_start: int, num_samples: int, batch_size: int + ) -> PopulatingStrategy: + return PopulatingStrategy( + reader_q_key=reader_q_key, + logging_message_q_key=logging_message_q_key, + index_start=index_start, + num_samples=num_samples, + batch_size=batch_size, + ) + + @staticmethod + def get_reader_strategy( + reader_settings: TokenizationInstantiationModel.ReaderWorkerSettings.ReaderSettings, + tokenizer_q_key: str, + logging_message_q_key: str, + ) -> ReadingStrategy: + reader_type = reader_settings.reader_type + if reader_type == LargeFileLinesReaderTypes.LOCAL: + return ReadingStrategy( + LargeFileLinesReaderFactory.get_local_reader, + reader_settings.reader_args, + tokenizer_q_key, + logging_message_q_key, + ) + elif reader_type == LargeFileLinesReaderTypes.GLOBAL: + return ReadingStrategy( + LargeFileLinesReaderFactory.get_global_reader, + reader_settings.reader_args, + tokenizer_q_key, + logging_message_q_key, + ) + else: + raise ValueError(f"Reader type {reader_type} is not supported.") + + @staticmethod + def get_tokenizer_strategy( + tokenizer_settings: TokenizationInstantiationModel.TokenizerWorkerSettings.TokenizerSettings, + writer_q_key: str, + logging_message_q_key: str, + ) -> TokenizingStrategy: + tokenizing_strategy = TokenizingStrategy( + ti_settings=tokenizer_settings.tokenizer_instantiation_settings, + eod_token=tokenizer_settings.eod_token, + jq_pattern=tokenizer_settings.jq_pattern, + writer_q_key=writer_q_key, + logging_message_q_key=logging_message_q_key, + ) + return tokenizing_strategy + + @staticmethod + def get_writing_strategy( + ww_settings: TokenizationInstantiationModel.WriterWorkerSettings, + logging_message_q_key: str, + ) -> WritingStrategy: + writing_strategy = WritingStrategy( + dst_path=ww_settings.dst_path, + index_start=ww_settings.index_start, + logging_message_q_key=logging_message_q_key, + ) + return writing_strategy + + @staticmethod + def get_progress_logging_strategy( + logging_interval: int, + total_num_samples: int, + q_dict: dict[str, mp.Queue], + ) -> ProgressLoggingStrategy: + return ProgressLoggingStrategy( + logging_interval=logging_interval, + total_num_samples=total_num_samples, + q_dict=q_dict, + ) + + @staticmethod + def get_process_queues( + reader_q_maxsize: int, tokenizer_q_maxsize: int, writer_q_maxsize + ) -> tuple[mp.Queue, mp.Queue, mp.Queue, mp.Queue]: + reader_q = mp.Queue(maxsize=reader_q_maxsize) # containes line_ids to be read + tokenizer_q = mp.Queue(maxsize=tokenizer_q_maxsize) # contains (line_id, line) pairs to be tokenized + writer_q = mp.Queue(maxsize=writer_q_maxsize) # contains (line_id, tokenized_line) to be written to disc + logging_message_q = mp.Queue() + return reader_q, tokenizer_q, writer_q, logging_message_q diff --git a/src/modalities/dataloader/preprocessing/tokenization/worker_types.py b/src/modalities/dataloader/preprocessing/tokenization/worker_types.py new file mode 100644 index 000000000..90ae24cf0 --- /dev/null +++ b/src/modalities/dataloader/preprocessing/tokenization/worker_types.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class WorkerTypes(Enum): + POPULATOR = "POPULATOR" + READER = "READER" + TOKENIZER = "TOKENIZER" + WRITER = "WRITER" + LOGGING = "LOGGING" diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index 0ac49dcc1..03d0abc17 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -24,3 +24,23 @@ 
class OptimizerError(Exception): class ConfigError(Exception): pass + + +class EmptySampleError(RuntimeError): + pass + + +class ReaderIndexationError(Exception): + pass + + +class ProcessorStopEventException(Exception): + pass + + +class ProcessorException(Exception): + pass + + +class ProcessingStrategyDoneException(Exception): + pass diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index ecda86995..f712a6fe1 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -54,6 +54,11 @@ from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.dataset import DummyDatasetConfig from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import ( + GlobalLargeFileLinesReaderConfig, + LargeFileLinesReaderFactory, + LocalLargeFileLinesReaderConfig, +) from modalities.dataloader.samplers import ResumableDistributedSampler from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, @@ -87,14 +92,16 @@ from modalities.utils.number_conversion import ( LocalNumBatchesFromNumSamplesConfig, LocalNumBatchesFromNumTokensConfig, - NumberConversion, NumberConversionFromCheckpointPathConfig, NumSamplesFromNumTokensConfig, + NumSamplesFromReaderConfig, NumStepsFromNumSamplesConfig, NumStepsFromNumTokensConfig, NumStepsFromRawDatasetIndexConfig, NumTokensFromNumStepsConfig, NumTokensFromPackedMemMapDatasetContinuousConfig, + PreprocessingNumberConversion, + TrainingNumberConversion, ) @@ -246,83 +253,102 @@ class ComponentEntity: "gradient_clipper", "fsdp_logging_only", FSDPLoggingOnlyGradientClipper, FSDPDummyGradientClipperConfig ), ComponentEntity("gradient_clipper", "dummy", DummyGradientClipper, DummyGradientClipperConfig), + # large file lines reader + ComponentEntity( + "large_file_lines_reader", + "local", + LargeFileLinesReaderFactory.get_local_reader, + LocalLargeFileLinesReaderConfig, + ), + ComponentEntity( + "large_file_lines_reader", + "global", + LargeFileLinesReaderFactory.get_global_reader, + GlobalLargeFileLinesReaderConfig, + ), # Number conversion ComponentEntity( - "number_conversion", + "training_number_conversion", "local_num_batches_from_num_samples", - NumberConversion.get_local_num_batches_from_num_samples, + TrainingNumberConversion.get_local_num_batches_from_num_samples, LocalNumBatchesFromNumSamplesConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "local_num_batches_from_num_tokens", - NumberConversion.get_local_num_batches_from_num_tokens, + TrainingNumberConversion.get_local_num_batches_from_num_tokens, LocalNumBatchesFromNumTokensConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_samples_from_num_tokens", - NumberConversion.get_num_samples_from_num_tokens, + TrainingNumberConversion.get_num_samples_from_num_tokens, NumSamplesFromNumTokensConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_steps_from_num_samples", - NumberConversion.get_num_steps_from_num_samples, + TrainingNumberConversion.get_num_steps_from_num_samples, NumStepsFromNumSamplesConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_steps_from_num_tokens", - NumberConversion.get_num_steps_from_num_tokens, + TrainingNumberConversion.get_num_steps_from_num_tokens, NumStepsFromNumTokensConfig, ), ComponentEntity( - "number_conversion", +
"training_number_conversion", "num_tokens_from_num_steps", - NumberConversion.get_num_tokens_from_num_steps, + TrainingNumberConversion.get_num_tokens_from_num_steps, NumTokensFromNumStepsConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "last_step_from_checkpoint_path", - NumberConversion.get_last_step_from_checkpoint_path, + TrainingNumberConversion.get_last_step_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_seen_steps_from_checkpoint_path", - NumberConversion.get_num_seen_steps_from_checkpoint_path, + TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "global_num_seen_tokens_from_checkpoint_path", - NumberConversion.get_global_num_seen_tokens_from_checkpoint_path, + TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_target_steps_from_checkpoint_path", - NumberConversion.get_num_target_steps_from_checkpoint_path, + TrainingNumberConversion.get_num_target_steps_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "global_num_target_tokens_from_checkpoint_path", - NumberConversion.get_global_num_target_tokens_from_checkpoint_path, + TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path, NumberConversionFromCheckpointPathConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_tokens_from_packed_mem_map_dataset_continuous", - NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous, + TrainingNumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous, NumTokensFromPackedMemMapDatasetContinuousConfig, ), ComponentEntity( - "number_conversion", + "training_number_conversion", "num_steps_from_raw_dataset_index", - NumberConversion.get_num_steps_from_raw_dataset_index, + TrainingNumberConversion.get_num_steps_from_raw_dataset_index, NumStepsFromRawDatasetIndexConfig, ), + ComponentEntity( + "preprocessing_number_conversion", + "num_samples", + PreprocessingNumberConversion.get_num_samples_from_reader, + NumSamplesFromReaderConfig, + ), ] diff --git a/src/modalities/tokenization/tokenizer_wrapper.py b/src/modalities/tokenization/tokenizer_wrapper.py index 211f5801f..f02643c17 100644 --- a/src/modalities/tokenization/tokenizer_wrapper.py +++ b/src/modalities/tokenization/tokenizer_wrapper.py @@ -260,7 +260,7 @@ def get_token_id(self, token: str) -> int: if not isinstance(piece_id, int): raise ValueError("Token cannot be represented by a single token ID!") if piece_id == self.tokenizer.unk_id(): - raise ValueError("Token cannot be represented by a single token id!") + raise ValueError("Token cannot be represented by a single token id!") return piece_id def is_special_token_id(self, token_id: int) -> bool: diff --git a/src/modalities/utils/env_variables.py b/src/modalities/utils/env_variables.py new file mode 100644 index 000000000..89d7150df --- /dev/null +++ b/src/modalities/utils/env_variables.py @@ -0,0 +1,56 @@ +import os +from contextlib import contextmanager +from functools import wraps +from typing import Any + + +@contextmanager +def temporary_env_var(key, value): + """ + Temporarily set an environment variable. 
+ + Args: + key (str): The environment variable name. + value (str): The temporary value to set. + """ + original_value = os.environ.get(key) # Store the original value (if any) + os.environ[key] = value # Set the temporary value + try: + yield # Allow code execution within the context + finally: + # Restore the original value or delete the key if it wasn't set originally + if original_value is None: + del os.environ[key] + else: + os.environ[key] = original_value + + +def temporary_env_vars_decorator(env_vars: dict[str, Any]): + """ + Decorator to temporarily set multiple environment variables for the duration of a function call. + + Args: + env_vars (dict): A dictionary of environment variable names and their temporary values. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + original_values = {} # Store original values of environment variables + try: + # Set the temporary environment variables + for key, value in env_vars.items(): + original_values[key] = os.environ.get(key) # Save original value + os.environ[key] = value # Set temporary value + return func(*args, **kwargs) # Execute the decorated function + finally: + # Restore original values or delete keys if not originally set + for key, original_value in original_values.items(): + if original_value is None: + del os.environ[key] + else: + os.environ[key] = original_value + + return wrapper + + return decorator diff --git a/src/modalities/utils/logging.py b/src/modalities/utils/logging.py new file mode 100644 index 000000000..21eda110b --- /dev/null +++ b/src/modalities/utils/logging.py @@ -0,0 +1,10 @@ +import logging + +def get_logger(name: str = "main") -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.setLevel(logging.DEBUG) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s')) + logger.addHandler(handler) + return logger \ No newline at end of file diff --git a/src/modalities/utils/number_conversion.py b/src/modalities/utils/number_conversion.py index 3dc56732b..201183b5f 100644 --- a/src/modalities/utils/number_conversion.py +++ b/src/modalities/utils/number_conversion.py @@ -1,10 +1,13 @@ import re +from functools import lru_cache from pathlib import Path from typing import Annotated from pydantic import BaseModel, Field +from modalities.config.pydanctic_if_types import PydanticBaseReaderIFType from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import BaseReaderIF class LocalNumBatchesFromNumSamplesConfig(BaseModel): @@ -67,7 +70,13 @@ class NumStepsFromRawDatasetIndexConfig(BaseModel): gradient_accumulation_steps: Annotated[int, Field(strict=True, gt=0)] -class NumberConversion: +class NumSamplesFromReaderConfig(BaseModel): + reader: PydanticBaseReaderIFType + index_start: Annotated[int, Field(strict=True, ge=0)] = 0 + num_samples: Annotated[int, Field(strict=True, ge=1)] = None + + +class TrainingNumberConversion: @staticmethod def _get_checkpoint_parameter_value(pattern: str, string: str) -> int: matches = re.findall(pattern, string) @@ -134,7 +143,7 @@ def get_local_num_batches_from_num_tokens( int: Number of local batches for single rank. 
""" global_num_samples = global_num_tokens // sequence_length - return NumberConversion.get_local_num_batches_from_num_samples( + return TrainingNumberConversion.get_local_num_batches_from_num_samples( num_ranks=num_ranks, global_num_samples=global_num_samples, local_micro_batch_size=local_micro_batch_size ) @@ -178,7 +187,7 @@ def get_num_steps_from_num_tokens( int: Number of steps. """ global_num_samples = global_num_tokens // sequence_length - return NumberConversion.get_num_steps_from_num_samples( + return TrainingNumberConversion.get_num_steps_from_num_samples( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, global_num_samples=global_num_samples, @@ -221,7 +230,7 @@ def get_last_step_from_checkpoint_path(checkpoint_path: Path) -> int: """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"seen_steps_(\d+)" - num_seen_steps = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_seen_steps = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_seen_steps - 1 @staticmethod @@ -236,7 +245,7 @@ def get_num_seen_steps_from_checkpoint_path(checkpoint_path: Path) -> int: """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"seen_steps_(\d+)" - num_seen_steps = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_seen_steps = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_seen_steps @staticmethod @@ -251,7 +260,7 @@ def get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path: Path) -> in """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"seen_tokens_(\d+)" - num_seen_tokens = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_seen_tokens = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_seen_tokens @staticmethod @@ -266,16 +275,18 @@ def get_global_num_target_tokens_from_checkpoint_path(checkpoint_path: Path) -> """ # Regex pattern to match 'num_steps_' followed by digits pattern = r"target_tokens_(\d+)" - num_target_tokens = NumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) + num_target_tokens = TrainingNumberConversion._get_checkpoint_parameter_value(pattern, str(checkpoint_path)) return num_target_tokens @staticmethod def get_num_target_steps_from_checkpoint_path(checkpoint_path: Path) -> int: - tokens_per_step = NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path) / ( - NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path) + 1 + tokens_per_step = TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path) / ( + TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path) + 1 ) - global_num_target_tokens = NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path) + global_num_target_tokens = TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path( + checkpoint_path + ) num_target_steps = global_num_target_tokens // tokens_per_step if isinstance(num_target_steps, float) and not num_target_steps.is_integer(): @@ -315,7 +326,7 @@ def get_num_tokens_from_packed_mem_map_dataset_continuous( raw_data_path=dataset_path, sequence_length=sequence_length, sample_key="text" ) global_num_tokens_dataset = len(dataset) * sequence_length - num_steps = NumberConversion.get_num_steps_from_num_tokens( + num_steps = 
TrainingNumberConversion.get_num_steps_from_num_tokens( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, global_num_tokens=global_num_tokens_dataset, @@ -323,7 +334,7 @@ def get_num_tokens_from_packed_mem_map_dataset_continuous( gradient_accumulation_steps=gradient_accumulation_steps, ) - global_num_tokens_actual = NumberConversion.get_num_tokens_from_num_steps( + global_num_tokens_actual = TrainingNumberConversion.get_num_tokens_from_num_steps( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, sequence_length=sequence_length, @@ -356,10 +367,25 @@ def get_num_steps_from_raw_dataset_index( """ index = DatasetFactory.get_raw_index(raw_index_path=raw_index_path) num_samples = len(index) - num_steps = NumberConversion.get_num_steps_from_num_samples( + num_steps = TrainingNumberConversion.get_num_steps_from_num_samples( num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, global_num_samples=num_samples, gradient_accumulation_steps=gradient_accumulation_steps, ) return num_steps + + +class PreprocessingNumberConversion: + @lru_cache(maxsize=128) + @staticmethod + def get_num_samples_from_reader(reader: BaseReaderIF, index_start: int = 0, num_samples: int = None): + max_num_samples = len(reader) - index_start + if num_samples is not None and num_samples > max_num_samples: + raise ValueError( + f"num_samples ({num_samples}) is greater than the maximum number of samples " + f"(len(large_file_lines_reader) - index_start = {max_num_samples})" + ) + if num_samples is None: + num_samples = max_num_samples + return num_samples diff --git a/tests/conftest.py b/tests/conftest.py index c05cb2c80..25b5d3bce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,9 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving from modalities.config.config import load_app_config_dict -from modalities.dataloader.create_index import IndexGenerator from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +from modalities.dataloader.preprocessing.indexation.local_indexation import IndexGenerator +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader from modalities.evaluator import Evaluator from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss @@ -67,7 +67,7 @@ def dummy_data_path(tmpdir) -> DataPathCollection: source_raw_dummy_data_path = _ROOT_DIR / Path("./data/lorem_ipsum.jsonl") dummy_data_path = Path(tmpdir, source_raw_dummy_data_path.name) dummy_data_path.write_text(source_raw_dummy_data_path.read_text()) - index_path = LargeFileLinesReader.default_index_path(dummy_data_path) + index_path = LocalLargeFileLinesReader.default_index_path(dummy_data_path) index_path.unlink(missing_ok=True) return DataPathCollection(raw_data_path=dummy_data_path, index_path=index_path) @@ -77,7 +77,7 @@ def dummy_data_path_long(tmpdir) -> DataPathCollection: source_raw_dummy_data_path = _ROOT_DIR / Path("./data/lorem_ipsum_long.jsonl") dummy_data_path = Path(tmpdir, source_raw_dummy_data_path.name) dummy_data_path.write_text(source_raw_dummy_data_path.read_text()) - index_path = LargeFileLinesReader.default_index_path(dummy_data_path) + index_path = LocalLargeFileLinesReader.default_index_path(dummy_data_path) index_path.unlink(missing_ok=True) return DataPathCollection(raw_data_path=dummy_data_path, index_path=index_path) diff --git 
a/tests/dataloader/test_large_file_lines_reader.py b/tests/dataloader/test_large_file_lines_reader.py index 1234c5edd..177008142 100644 --- a/tests/dataloader/test_large_file_lines_reader.py +++ b/tests/dataloader/test_large_file_lines_reader.py @@ -7,7 +7,7 @@ import pytest from modalities.dataloader.create_index import IndexGenerator -from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader from tests.conftest import DataPathCollection @@ -75,7 +75,7 @@ def generate_data_index_file(data_path: Path, **kwargs): ) def test_large_file_lines_reader_text(indexed_dummy_data_path: DataPathCollection, use_sample_length_from_index: bool): raw_data_path = indexed_dummy_data_path.raw_data_path - reader = LargeFileLinesReader( + reader = LocalLargeFileLinesReader( raw_data_path, use_sample_length_from_index=use_sample_length_from_index, encoding="utf-8" ) assert raw_data_path.read_text().count("\n") == 12 @@ -106,10 +106,10 @@ def test_large_file_lines_reader_binary_text_equivalence( indexed_dummy_data_path: DataPathCollection, use_sample_length_from_index: bool ): raw_data_path = indexed_dummy_data_path.raw_data_path - reader_binary = LargeFileLinesReader( + reader_binary = LocalLargeFileLinesReader( raw_data_path, use_sample_length_from_index=use_sample_length_from_index, encoding=None ) - reader_text = LargeFileLinesReader( + reader_text = LocalLargeFileLinesReader( raw_data_path, use_sample_length_from_index=use_sample_length_from_index, encoding="utf-8" ) @@ -124,4 +124,4 @@ def test_large_file_lines_reader_missing_source_data(dummy_data_path: DataPathCo raw_data_path.unlink(missing_ok=True) assert not raw_data_path.exists() with pytest.raises(FileNotFoundError): - LargeFileLinesReader(raw_data_path, dummy_data_path.index_path) + LocalLargeFileLinesReader(raw_data_path, dummy_data_path.index_path) diff --git a/tests/dataloader/yaml_configs/skipped_dataloader.yaml b/tests/dataloader/yaml_configs/skipped_dataloader.yaml index ddd81bbe1..b7f5910e8 100644 --- a/tests/dataloader/yaml_configs/skipped_dataloader.yaml +++ b/tests/dataloader/yaml_configs/skipped_dataloader.yaml @@ -29,7 +29,7 @@ train_dataset: sample_key: ${settings.referencing_keys.sample_key} skip_num_samples: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_samples_from_num_tokens config: num_tokens: ${settings.training.global_num_seen_tokens} diff --git a/tests/end2end_tests/gpt2_train_num_steps_8.yaml b/tests/end2end_tests/gpt2_train_num_steps_8.yaml index 4954e6a92..cef1a88a0 100644 --- a/tests/end2end_tests/gpt2_train_num_steps_8.yaml +++ b/tests/end2end_tests/gpt2_train_num_steps_8.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml 
b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml index 1a7c9da6b..5bf7b0669 100644 --- a/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml +++ b/tests/end2end_tests/gpt2_warm_start_from_step_4.yaml @@ -27,28 +27,28 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_target_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_target_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} training_progress: global_num_seen_tokens: # used below - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_seen_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_seen_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_seen_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_seen_samples: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_samples_from_num_tokens config: num_tokens: ${settings.training_progress.global_num_seen_tokens} diff --git a/tests/test_yaml_configs/config_lorem_ipsum.yaml b/tests/test_yaml_configs/config_lorem_ipsum.yaml index e9552785b..f0582a89f 100644 --- a/tests/test_yaml_configs/config_lorem_ipsum.yaml +++ b/tests/test_yaml_configs/config_lorem_ipsum.yaml @@ -26,7 +26,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -35,7 +35,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tests/utils/test_number_conversion.py b/tests/utils/test_number_conversion.py index f54d807ae..531aaad98 100644 --- a/tests/utils/test_number_conversion.py +++ b/tests/utils/test_number_conversion.py @@ -4,7 +4,7 @@ import pytest from modalities.dataloader.dataset_factory import DatasetFactory -from modalities.utils.number_conversion import NumberConversion +from modalities.utils.number_conversion import TrainingNumberConversion @pytest.mark.parametrize( @@ -15,7 +15,9 @@ def test_get_local_num_batches_from_num_samples( num_ranks: int, global_num_samples: int, local_micro_batch_size: int, expected: int ): assert ( - NumberConversion.get_local_num_batches_from_num_samples(num_ranks, global_num_samples, local_micro_batch_size) + TrainingNumberConversion.get_local_num_batches_from_num_samples( + num_ranks, global_num_samples, local_micro_batch_size + ) == expected ) @@ -28,7 +30,7 @@ def test_get_local_num_batches_from_num_tokens( num_ranks: int, global_num_tokens: int, sequence_length: int, local_micro_batch_size: int, expected: int 
): assert ( - NumberConversion.get_local_num_batches_from_num_tokens( + TrainingNumberConversion.get_local_num_batches_from_num_tokens( num_ranks, global_num_tokens, sequence_length, local_micro_batch_size ) == expected @@ -47,7 +49,7 @@ def test_get_num_steps_from_num_samples( expected: int, ): assert ( - NumberConversion.get_num_steps_from_num_samples( + TrainingNumberConversion.get_num_steps_from_num_samples( num_ranks, local_micro_batch_size, global_num_samples, gradient_accumulation_steps ) == expected @@ -76,7 +78,7 @@ def test_get_num_steps_from_num_tokens( expected: int, ): assert ( - NumberConversion.get_num_steps_from_num_tokens( + TrainingNumberConversion.get_num_steps_from_num_tokens( num_ranks, local_micro_batch_size, global_num_tokens, sequence_length, gradient_accumulation_steps ) == expected @@ -101,7 +103,7 @@ def test_get_num_tokens_from_num_steps( expected: int, ): assert ( - NumberConversion.get_num_tokens_from_num_steps( + TrainingNumberConversion.get_num_tokens_from_num_steps( num_steps=num_steps, num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, @@ -141,9 +143,9 @@ def test_get_last_step_from_checkpoint_path(checkpoint_path: Path, expected: int if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) else: - assert NumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected + assert TrainingNumberConversion.get_last_step_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected @pytest.mark.parametrize( @@ -175,9 +177,12 @@ def test_get_num_seen_steps_from_checkpoint_path(checkpoint_path: Path, expected if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) else: - assert NumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected + assert ( + TrainingNumberConversion.get_num_seen_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + == expected + ) @pytest.mark.parametrize( @@ -211,10 +216,10 @@ def test_get_global_num_seen_tokens_from_checkpoint_path( if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) else: assert ( - NumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_seen_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected ) @@ -250,10 +255,10 @@ def test_get_global_num_target_tokens_from_checkpoint_path( if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) else: assert ( - 
NumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_global_num_target_tokens_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected ) @@ -287,9 +292,12 @@ def test_get_num_target_steps_from_checkpoint_path(checkpoint_path: Path, expect if expected_exception: # Expecting an exception for this test case with pytest.raises(expected_exception): - NumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + TrainingNumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) else: - assert NumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) == expected + assert ( + TrainingNumberConversion.get_num_target_steps_from_checkpoint_path(checkpoint_path=checkpoint_path) + == expected + ) @pytest.mark.parametrize( @@ -336,7 +344,7 @@ def test_get_num_tokens_from_packed_mem_map_dataset_continuous( ) assert ( - NumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous( + TrainingNumberConversion.get_num_tokens_from_packed_mem_map_dataset_continuous( dataset_path=dataset_path, sequence_length=sequence_length, num_ranks=num_ranks, @@ -369,7 +377,7 @@ def test_num_steps_from_raw_dataset_index( with open(raw_index_path, "rb") as f: index_length = len(pickle.load(f)) - num_steps_from_number_conversion = NumberConversion.get_num_steps_from_raw_dataset_index( + num_steps_from_number_conversion = TrainingNumberConversion.get_num_steps_from_raw_dataset_index( raw_index_path=raw_index_path, num_ranks=num_ranks, local_micro_batch_size=local_micro_batch_size, diff --git a/tutorials/getting_started/example_config.yaml b/tutorials/getting_started/example_config.yaml index f1737f940..965302e36 100644 --- a/tutorials/getting_started/example_config.yaml +++ b/tutorials/getting_started/example_config.yaml @@ -28,7 +28,7 @@ settings: sequence_length: 512 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -37,7 +37,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/library_usage/config_lorem_ipsum.yaml b/tutorials/library_usage/config_lorem_ipsum.yaml index 915e0ebd0..f8bdbb517 100644 --- a/tutorials/library_usage/config_lorem_ipsum.yaml +++ b/tutorials/library_usage/config_lorem_ipsum.yaml @@ -29,7 +29,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_num_steps config: num_steps: ${settings.training_target.num_target_steps} @@ -38,7 +38,7 @@ settings: sequence_length: ${settings.step_profile.sequence_length} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_raw_dataset_index config: raw_index_path: ${settings.paths.index_path} diff --git 
a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml index 166d25fb5..bd91e965c 100644 --- a/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml +++ b/tutorials/modalities_in_15_mins/configs/pretraining_config.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/warmstart/configs/pre_training_config.yaml b/tutorials/warmstart/configs/pre_training_config.yaml index 30db4adf6..e3e83a64a 100644 --- a/tutorials/warmstart/configs/pre_training_config.yaml +++ b/tutorials/warmstart/configs/pre_training_config.yaml @@ -27,7 +27,7 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: num_tokens_from_packed_mem_map_dataset_continuous config: dataset_path: ${settings.paths.train_dataset_path} @@ -36,7 +36,7 @@ settings: local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_steps_from_num_tokens config: num_ranks: ${settings.cuda_env.world_size} diff --git a/tutorials/warmstart/configs/warmstart_config.yaml b/tutorials/warmstart/configs/warmstart_config.yaml index 1858d9a11..6fc68619a 100644 --- a/tutorials/warmstart/configs/warmstart_config.yaml +++ b/tutorials/warmstart/configs/warmstart_config.yaml @@ -27,28 +27,28 @@ settings: sequence_length: 256 training_target: num_target_tokens: - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_target_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_target_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_target_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} training_progress: global_num_seen_tokens: # used below - component_key: number_conversion + component_key: training_number_conversion variant_key: global_num_seen_tokens_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} num_seen_steps: # for the batch progress subscriber - component_key: number_conversion + component_key: training_number_conversion variant_key: num_seen_steps_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} local_num_seen_batches: # for the dataloader - component_key: number_conversion + component_key: training_number_conversion variant_key: local_num_batches_from_num_tokens config: num_ranks: 
${settings.cuda_env.world_size} @@ -56,7 +56,7 @@ settings: sequence_length: ${settings.step_profile.sequence_length} local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} last_step: # for the scheduler - component_key: number_conversion + component_key: training_number_conversion variant_key: last_step_from_checkpoint_path config: checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path}
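
Appendix (illustrative sketches, not part of the patch)

The *_from_checkpoint_path variants wired into the warm-start configs above all go through TrainingNumberConversion._get_checkpoint_parameter_value with the regex patterns visible in the diff (seen_steps_(\d+), seen_tokens_(\d+), target_tokens_(\d+)). The helper's body is not part of this diff, so the sketch below only approximates it, and the checkpoint filename is made up purely for illustration.

import re
from pathlib import Path


def get_checkpoint_parameter_value(pattern: str, checkpoint_path: str) -> int:
    # Approximation of TrainingNumberConversion._get_checkpoint_parameter_value:
    # extract the first integer captured by the given pattern from the path string.
    match = re.search(pattern, checkpoint_path)
    if match is None:
        raise ValueError(f"Pattern {pattern!r} not found in checkpoint path {checkpoint_path!r}")
    return int(match.group(1))


# Illustrative checkpoint name only; the real naming scheme is defined elsewhere in the repo.
checkpoint_path = Path("checkpoints/model-seen_steps_48-seen_tokens_786432-target_tokens_1572864.bin")

num_seen_steps = get_checkpoint_parameter_value(r"seen_steps_(\d+)", str(checkpoint_path))    # 48
last_step = num_seen_steps - 1                                                                # 47, cf. get_last_step_from_checkpoint_path
seen_tokens = get_checkpoint_parameter_value(r"seen_tokens_(\d+)", str(checkpoint_path))      # 786432
target_tokens = get_checkpoint_parameter_value(r"target_tokens_(\d+)", str(checkpoint_path))  # 1572864

# Mirrors get_num_target_steps_from_checkpoint_path in the diff:
tokens_per_step = seen_tokens / (last_step + 1)       # 786432 / 48 = 16384.0
num_target_steps = target_tokens // tokens_per_step   # 96.0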
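Most of the YAML changes above only re-point component_key from number_conversion to training_number_conversion; the variants themselves (num_steps_from_num_tokens, num_tokens_from_num_steps, num_samples_from_num_tokens, ...) remain simple integer arithmetic. The sketch below is a rough approximation: only the tokens-to-samples division appears verbatim in the diff, and the assumption that one step consumes num_ranks * local_micro_batch_size * gradient_accumulation_steps samples is mine, not the repository's code.

def num_steps_from_num_tokens(
    num_ranks: int,
    local_micro_batch_size: int,
    global_num_tokens: int,
    sequence_length: int,
    gradient_accumulation_steps: int,
) -> int:
    # Shown in the diff: tokens are first converted to fixed-length samples.
    global_num_samples = global_num_tokens // sequence_length
    # Assumed step size: samples consumed per optimizer step across all ranks.
    samples_per_step = num_ranks * local_micro_batch_size * gradient_accumulation_steps
    return global_num_samples // samples_per_step


# Hypothetical numbers, not taken from any config above:
# 1_048_576 tokens / 256 sequence length = 4_096 samples;
# 2 ranks * 4 micro batch * 1 accumulation step = 8 samples per step -> 512 steps.
assert num_steps_from_num_tokens(
    num_ranks=2,
    local_micro_batch_size=4,
    global_num_tokens=1_048_576,
    sequence_length=256,
    gradient_accumulation_steps=1,
) == 512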
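The one genuinely new piece in number_conversion.py is PreprocessingNumberConversion.get_num_samples_from_reader, which validates and defaults a requested sample count against what a BaseReaderIF can provide. Below is a minimal, self-contained sketch of its semantics; the reader stub and names are illustrative, not from the repository. Two hedged usage caveats: because the method is wrapped in lru_cache, the reader argument must be hashable, and stacking @lru_cache outside @staticmethod relies on staticmethod objects being callable (Python 3.10+); the more conventional ordering puts @staticmethod outermost.

from typing import Optional


class FakeReader:
    # Stand-in for a BaseReaderIF implementation; only __len__ matters here.
    def __init__(self, num_lines: int) -> None:
        self._num_lines = num_lines

    def __len__(self) -> int:
        return self._num_lines


def get_num_samples_from_reader(reader, index_start: int = 0, num_samples: Optional[int] = None) -> int:
    # Same logic as the diff's PreprocessingNumberConversion.get_num_samples_from_reader,
    # re-implemented standalone so the example runs without the modalities package.
    max_num_samples = len(reader) - index_start
    if num_samples is not None and num_samples > max_num_samples:
        raise ValueError(
            f"num_samples ({num_samples}) is greater than the maximum number of samples ({max_num_samples})"
        )
    return max_num_samples if num_samples is None else num_samples


reader = FakeReader(num_lines=1_000)
assert get_num_samples_from_reader(reader) == 1_000                 # take everything
assert get_num_samples_from_reader(reader, index_start=100) == 900  # skip the first 100 lines
assert get_num_samples_from_reader(reader, num_samples=250) == 250  # explicit cap wins if it fits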