From feabe5a8721a76eb755e37b85a310ae2e233dd28 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 6 Dec 2025 12:30:25 +0100 Subject: [PATCH 01/14] refactor: removed unnecessary all-reduce ops and improved accuracy of time measurement --- src/modalities/evaluator.py | 32 ++++++++++++++++--------------- src/modalities/main.py | 1 + src/modalities/trainer.py | 37 ++++++++++++------------------------ src/modalities/util.py | 38 +------------------------------------ 4 files changed, 31 insertions(+), 77 deletions(-) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 49bdaa11d..dff8dcdd8 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh from modalities.batch import DatasetBatch, EvaluationResultBatch, InferenceResultBatch, ResultItem from modalities.dataloader.dataloader import LLMDataLoader @@ -10,9 +11,9 @@ from modalities.logging_broker.publisher import MessagePublisher from modalities.models.model import model_predict_batch from modalities.models.parallelism.pipeline_parallelism import Pipeline +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_degree from modalities.running_env.fsdp.reducer import Reducer -from modalities.trainer import ThroughputAggregationKeys -from modalities.util import Aggregator, TimeRecorder +from modalities.util import TimeRecorder class Evaluator: @@ -22,6 +23,7 @@ def __init__( self, progress_publisher: MessagePublisher[ProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], + device_mesh: DeviceMesh | None, ) -> None: """Initializes the Evaluator class. 
@@ -31,6 +33,14 @@ def __init__( """ self.progress_publisher = progress_publisher self.evaluation_result_publisher = evaluation_result_publisher + if device_mesh is not None: + self.dp_degree = get_parallel_degree( + device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD] + ) + self.pp_degree = get_parallel_degree(device_mesh, [ParallelismDegrees.PP]) + else: # TODO: we can remove the else part once we refactored out FSDP1 + self.dp_degree = dist.get_world_size() + self.pp_degree = 1 def evaluate_batch( self, @@ -102,6 +112,7 @@ def evaluate( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") for data_loader in data_loaders: + local_num_seen_samples = 0 cumulated_loss = torch.zeros(3).to(device) Evaluator._publish_progress( @@ -109,7 +120,6 @@ def evaluate( num_eval_steps_done=0, # Reset progress bar dataloader_tag=data_loader.dataloader_tag, ) - throughput_aggregator = Aggregator[ThroughputAggregationKeys]() with TimeRecorder() as forward_backward_timer_recorder: for batch_id, batch in enumerate(data_loader): batch_loss = self.evaluate_batch( @@ -123,10 +133,7 @@ def evaluate( if batch_loss is not None: cumulated_loss[0] += batch_loss.item() # sum up batch loss cumulated_loss[1] += 1 - batch_length_tensor = torch.tensor(len(batch)).to(device) - throughput_aggregator.add_value( - key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor - ) + local_num_seen_samples += torch.tensor(len(batch)).to(device) Evaluator._publish_progress( progress_publisher=self.progress_publisher, @@ -141,14 +148,9 @@ def evaluate( ) forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) - throughput_aggregator.add_value( - key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time - ) - synced_num_samples = throughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES) - synced_forward_backward_time = throughput_aggregator.get_all_reduced_value( - ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX - ) - num_samples_per_second = synced_num_samples / synced_forward_backward_time + global_num_seen_samples = local_num_seen_samples * self.dp_degree + + num_samples_per_second = global_num_seen_samples / forward_backward_time evaluation_result = EvaluationResultBatch( losses={loss_fun.tag: ResultItem(total_loss, decimal_places=2)}, diff --git a/src/modalities/main.py b/src/modalities/main.py index 85a80731e..1fc46fdab 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -186,6 +186,7 @@ def run(self, components: TrainingComponentsInstantiationModel): evaluator = Evaluator( progress_publisher=progress_publisher, evaluation_result_publisher=evaluation_result_publisher, + device_mesh=components.device_mesh, ) # Gym diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 731802c4c..ac00b41b8 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -21,7 +21,7 @@ from modalities.running_env.fsdp.reducer import Reducer from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF from modalities.training.training_progress import TrainingProgress -from modalities.util import Aggregator, TimeRecorder, print_rank_0 +from modalities.util import TimeRecorder, print_rank_0 from modalities.utils.mfu import MFUCalculatorABC @@ -73,7 +73,7 @@ def __init__( device_mesh, [ParallelismDegrees.DP_REPLICATE, ParallelismDegrees.DP_SHARD] ) self.pp_degree = get_parallel_degree(device_mesh, [ParallelismDegrees.PP]) 
- else: + else: # TODO: we can remove the else part once we refactored out FSDP1 self.dp_degree = dist.get_world_size() self.pp_degree = 1 self.progress_publisher = progress_publisher @@ -201,10 +201,10 @@ def train( lr_scheduler = app_state.lr_scheduler model.train() + local_num_seen_samples = 0 cumulated_losses = self._reset_tracked_losses() # throughput - thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # batch loop @@ -246,7 +246,6 @@ def train( micro_batch_id=micro_batch_id, scheduled_pipeline=scheduled_pipeline, ) - forward_backward_time_recorder.stop() training_progress.num_seen_steps_current_run = num_train_steps_done training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done @@ -261,8 +260,7 @@ def train( if gradient_norm_score is not None: gradient_norm_scores.append(gradient_norm_score.item()) - batch_length_tensor = torch.tensor(len(batch)).to(device) - thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) + local_num_seen_samples += torch.tensor(len(batch)).to(device) self._publish_progress( progress_publisher=self.progress_publisher, @@ -271,24 +269,17 @@ def train( ) # Check if model performance should be logged if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed: + forward_backward_time_recorder.stop() forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device) forward_backward_time_recorder.reset() + forward_backward_time_recorder.start() + + global_num_seen_samples = local_num_seen_samples * self.dp_degree + local_num_seen_samples = 0 + global_num_samples_per_second = global_num_seen_samples / forward_backward_time - thoughput_aggregator.add_value( - key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time - ) - # we only want to sync the num samples across data parallel ranks - # so we divide the world size by the dp degree - synced_num_samples = thoughput_aggregator.get_all_reduced_value( - ThroughputAggregationKeys.NUM_SAMPLES - ) / (dist.get_world_size() / self.dp_degree) - synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value( - ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX - ) - synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP # add the loss and gradient norm for the LAST batch - cumulated_losses[1] = batch_loss.item() if batch_loss is not None else 0.0 reduced_losses = Reducer.reduce( @@ -319,7 +310,7 @@ def train( gradient_norm_scores = [] mfu_score = torch.tensor(-1.0) if self.mfu_calculator is not None: - mfu_score = self.mfu_calculator.compute(num_samples_per_second=synced_num_samples_per_second) + mfu_score = self.mfu_calculator.compute(num_samples_per_second=global_num_samples_per_second) # Collect peak memory depending on device type. On CPU we fall back to RSS (if available) or -1. 
if device.type == "cuda": @@ -339,7 +330,7 @@ def train( metrics=metrics, # TODO: hardcoded metric key throughput_metrics={ - "train samples/s": ResultItem(synced_num_samples_per_second, 1), + "train samples/s": ResultItem(global_num_samples_per_second, 1), "train mfu (16-bit)": ResultItem(mfu_score, 2), "lr mean": ResultItem(torch.tensor(lr_scheduler.get_last_lr()).mean()), "peak memory rank 0 (MB)": ResultItem(torch.tensor(peak_memory_MB), 2), @@ -352,15 +343,11 @@ def train( evaluation_result_publisher=self.evaluation_result_publisher, evaluation_result=training_metrics, ) - thoughput_aggregator.remove_keys() cumulated_losses = self._reset_tracked_losses() if step_performed: evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) checkpointing_callback(training_progress=training_progress) - # we start the time recoder here again to also capture the time spend loading - # via the dataloader. - forward_backward_time_recorder.start() def _reset_tracked_losses(self): # Initializes and returns a tensor representing the cumulated loss and gradient norm. diff --git a/src/modalities/util.py b/src/modalities/util.py index 4bff43859..3a2706e6e 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -5,7 +5,7 @@ from enum import Enum from pathlib import Path from types import TracebackType -from typing import Callable, Generic, Optional, Type, TypeVar +from typing import Optional, Type, TypeVar import torch import torch.distributed as dist @@ -19,7 +19,6 @@ from modalities.exceptions import TimeRecorderStateError from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, has_parallelism_method -from modalities.running_env.fsdp.reducer import Reducer from modalities.utils.typing_utils import FSDPX @@ -283,41 +282,6 @@ def __repr__(self) -> str: return f"{self.delta_t}s" -T = TypeVar("T") - - -class Aggregator(Generic[T]): - def __init__(self): - self.key_to_value: dict[T, torch.Tensor] = {} - - def add_value(self, key: T, value: torch.Tensor): - if key not in self.key_to_value: - self.key_to_value[key] = value - else: - self.key_to_value[key] += value - - def remove_key(self, key: T): - self.key_to_value.pop(key) - - def remove_keys(self): - self.key_to_value = {} - - def get_all_reduced_value( - self, - key: T, - reduce_operation: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, - postprocessing_fun: None | Callable[[torch.Tensor], torch.Tensor] = None, - ) -> torch.Tensor: - # we clone the value so that we can always resync the value without side-effects - cloned_value = self.key_to_value[key].clone() - value = Reducer.reduce( - tensor=cloned_value, - operation=reduce_operation, - post_processing_fun=postprocessing_fun, # lambda t: t[0] / t[1], - ) - return value - - def get_module_class_from_name(module: torch.nn.Module, name: str) -> Type[torch.nn.Module] | None: """From Accelerate source code (https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/utils/dataclasses.py#L1902) From a4775ad1c9869c7944309da6cec7385a9c984ca6 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 6 Dec 2025 12:32:13 +0100 Subject: [PATCH 02/14] chore: added documentation and renamed pytorch rms norm key --- src/modalities/models/gpt2/gpt2_model.py | 11 ++++++++--- src/modalities/registry/components.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index b50790d29..4fc97061e 
100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -142,9 +142,9 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq self._cos_cached = None self._sin_cached = None - def rotate_half(self, x): + def rotate_half(self, x: torch.Tensor): """ - Rearange tentor elements. + Rearrange tensor elements. Args: x (torch.Tensor): The input tensor. @@ -166,7 +166,9 @@ def _update_cos_sin_tables(self, x): self._seq_len_cached = seq_len t = torch.arange(x.shape[self.seq_length_dim], device=x.device, dtype=torch.float32) freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + emb = torch.cat((freqs, freqs), dim=-1).to( + x.device + ) # here, we combine the two matrices (not zipping them). self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) @@ -190,6 +192,9 @@ def apply_rotary_pos_emb(self, x, cos, sin): cos = cos[:, :, : x.shape[self.seq_length_dim], :] sin = sin[:, :, : x.shape[self.seq_length_dim], :] + # the rotation is not really a rotation in higher dimensions, + # it merely swaps and negates certain dimensions to make + # the rotation below work return (x * cos) + (self.rotate_half(x) * sin) def forward( diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 27358c97f..eaccbf2d3 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -333,7 +333,7 @@ class ComponentEntity: # layer norms ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig), ComponentEntity("layer_norm", "layer_norm", nn.LayerNorm, LayerNormConfig), - ComponentEntity("layer_norm", "rms_norm_pytorch", nn.RMSNorm, PytorchRMSLayerNormConfig), + ComponentEntity("layer_norm", "pytorch_rms_norm", nn.RMSNorm, PytorchRMSLayerNormConfig), # gradient clippers ComponentEntity("gradient_clipper", "fsdp1", FSDP1GradientClipper, FSDP1GradientClipperConfig), ComponentEntity( From 719e35e1d96e1931a59bca372aa4c67aeb3b7c99 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 6 Dec 2025 12:33:30 +0100 Subject: [PATCH 03/14] feat: added timestamp and dtype to debugged model for input/output activations and weights --- src/modalities/models/model_factory.py | 28 +++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 3acb17f95..77cc187d9 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -1,5 +1,6 @@ import itertools import json +import time from dataclasses import asdict, dataclass from functools import partial from pathlib import Path @@ -420,6 +421,7 @@ class TensorStats: global_shape: list[int] local_shape: list[int] + dtype: str is_dtensor: bool nan_count: int inf_count: int @@ -456,6 +458,7 @@ def get_tensor_stats(tensor: torch.Tensor) -> TensorStats: tensor_stats = TensorStats( global_shape=list(tensor.shape), local_shape=list(local_tensor.shape), + dtype=str(dtype), is_dtensor=isinstance(tensor, dist.tensor.DTensor), nan_count=torch.isnan(local_tensor).sum().item(), inf_count=torch.isinf(local_tensor).sum().item(), @@ -466,7 +469,9 @@ def get_tensor_stats(tensor: torch.Tensor) -> TensorStats: ) return tensor_stats - def write_out_tensor_stats(tensor_stats: TensorStats, counter: int, hook_type: str, tensor_tag: str, rank:
int): + def write_out_tensor_stats( + tensor_stats: TensorStats, counter: int, hook_type: str, tensor_tag: str, rank: int, timestamp_ns: int + ): """Write out tensor statistics to a file.""" with open(logging_file_path, "a", encoding="utf-8") as f: tensor_stats_dict = asdict(tensor_stats) @@ -476,11 +481,13 @@ def write_out_tensor_stats(tensor_stats: TensorStats, counter: int, hook_type: s **tensor_stats_dict, "counter": counter, "rank": rank, + "timestamp_ns": timestamp_ns, } f.write(json.dumps(tensor_stats_dict) + "\n") def pre_forward_hook(module: nn.Module, forward_input, counter: CounterRef, log_interval_steps: int): + timestamp_ns = time.perf_counter_ns() if log_interval_steps > 0 and counter.value % log_interval_steps != 0: counter.value += 1 return @@ -492,7 +499,9 @@ def pre_forward_hook(module: nn.Module, forward_input, counter: CounterRef, log_ for forward_input in forward_inputs: tensor_stats = get_tensor_stats(forward_input) - write_out_tensor_stats(tensor_stats, counter.value, "forward_input", module._debug_name, rank) + write_out_tensor_stats( + tensor_stats, counter.value, "forward_input", module._debug_name, rank, timestamp_ns + ) # Retrieves statistics of the module's parameters before forward pass. for name, param in module.named_parameters(recurse=False): @@ -504,10 +513,14 @@ def pre_forward_hook(module: nn.Module, forward_input, counter: CounterRef, log_ hook_type="forward_weights", tensor_tag=full_name, rank=rank, + timestamp_ns=timestamp_ns, ) counter.value += 1 - def forward_hook(module: nn.Module, foward_input, forward_output, counter: CounterRef, log_interval_steps: int): + def forward_hook( + module: nn.Module, forward_input, forward_output, counter: CounterRef, log_interval_steps: int + ): + timestamp_ns = time.perf_counter_ns() if log_interval_steps > 0 and counter.value % log_interval_steps != 0: counter.value += 1 return @@ -519,17 +532,22 @@ def forward_hook(module: nn.Module, foward_input, forward_output, counter: Count for out in forward_outputs: tensor_stats = get_tensor_stats(out) - write_out_tensor_stats(tensor_stats, counter.value, "forward_output", module._debug_name, rank) + write_out_tensor_stats( + tensor_stats, counter.value, "forward_output", module._debug_name, rank, timestamp_ns + ) counter.value += 1 def backward_hook(module, grad_input, grad_output, counter: CounterRef, log_interval_steps: int): + timestamp_ns = time.perf_counter_ns() if log_interval_steps > 0 and counter.value % log_interval_steps != 0: counter.value += 1 return for grad_out in grad_output: tensor_stats = get_tensor_stats(grad_out) - write_out_tensor_stats(tensor_stats, counter.value, "backward_output", module._debug_name, rank) + write_out_tensor_stats( + tensor_stats, counter.value, "backward_output", module._debug_name, rank, timestamp_ns + ) counter.value += 1 def register_hooks_recursively(module: nn.Module, prefix: str = ""): From 03354c1c543e851638815bf80e1a0f1cac7b0dc7 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:14:56 +0100 Subject: [PATCH 04/14] feat: steppable component can now perform backward pass and optimizer steps --- .../utils/profilers/modalities_profiler.py | 132 +++++++++++++++--- .../profilers/steppable_component_configs.py | 2 + .../utils/profilers/steppable_components.py | 15 +- 3 files changed, 124 insertions(+), 25 deletions(-) diff --git a/src/modalities/utils/profilers/modalities_profiler.py b/src/modalities/utils/profilers/modalities_profiler.py index 339c0fb09..58e3a7fa9 100644 
--- a/src/modalities/utils/profilers/modalities_profiler.py +++ b/src/modalities/utils/profilers/modalities_profiler.py @@ -1,4 +1,4 @@ -import os +import pickle import shutil from dataclasses import dataclass from pathlib import Path @@ -31,6 +31,90 @@ class CustomComponentRegisterable: custom_config: type +class SteppableProfilerIF: + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_value, traceback): + raise NotImplementedError + + def step(self): + raise NotImplementedError + + +class SteppableMemoryProfiler(SteppableProfilerIF): + MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000 + + def __init__(self, memory_snapshot_path: Path, num_wait_steps: int, num_warmup_steps: int, num_active_steps: int): + self._memory_snapshot_path = memory_snapshot_path + self._curr_step = None + self._num_wait_steps = num_wait_steps + self._num_warmup_steps = num_warmup_steps + self._num_active_steps = num_active_steps + + def __enter__(self): + self._curr_step = 0 + # start recording memory history if there is no wait / warmup steps + if self._curr_step == self._num_wait_steps + self._num_warmup_steps and self._num_active_steps > 0: + torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES) + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self._curr_step is None: + raise RuntimeError("SteppableMemoryProfilerContext exited without being entered") + if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps: + # if we exit before finishing all steps, dump the memory snapshot + raise RuntimeError("SteppableMemoryProfilerContext exited before finishing all steps") + return + + def step(self): + if self._curr_step is None: + raise RuntimeError("SteppableMemoryProfilerContext.step() called outside of context manager") + self._curr_step += 1 + if self._curr_step < self._num_wait_steps + self._num_warmup_steps: + return + elif self._curr_step == self._num_wait_steps + self._num_warmup_steps: + torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES) + elif ( + self._curr_step == self._num_wait_steps + self._num_warmup_steps + self._num_active_steps + and self._num_active_steps > 0 + ): + with open(self._memory_snapshot_path, "wb") as output: + pickle.dump(torch.cuda.memory._snapshot(), output) + + +class ProfilerListContext(SteppableProfilerIF): + def __init__(self, profiler_cms: list[SteppableProfilerIF]): + self.profiler_cms = profiler_cms + self._entered = None + + def __enter__(self): + if self._entered is not None: + raise RuntimeError("ProfilerListContext entered multiple times without exiting") + self._entered = [] + for profiler_cm in self.profiler_cms: + return_val = profiler_cm.__enter__() + if return_val is not None: + self._entered.append(return_val) + else: + self._entered.append(profiler_cm) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self._entered is None: + raise RuntimeError("ProfilerListContext exited without being entered") + for profiler_cm in self._entered: + profiler_cm.__exit__(exc_type, exc_value, traceback) + self._entered = None + + def step(self): + if self._entered is None: + raise RuntimeError("ProfilerListContext.step() called outside of context manager") + for profiler_cm in self._entered: + profiler_cm.step() + + class ModalitiesProfilerStarter: """Starter class to run profiling either in single process or distributed mode.""" @@ -71,7 +155,6 @@ def run_distributed( 
global_rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() - local_rank = int(os.environ["LOCAL_RANK"]) ModalitiesProfilerStarter._run_helper( config_file_path=config_file_path, @@ -79,7 +162,6 @@ def run_distributed( num_wait_steps=num_wait_steps, num_warmup_steps=num_warmup_steps, experiment_folder_path=experiment_root_path / experiment_id, - local_rank=local_rank, global_rank=global_rank, world_size=world_size, profiled_ranks=profiled_ranks, @@ -122,7 +204,6 @@ def run_single_process( global_rank = 0 world_size = 1 - local_rank = 0 profiled_ranks = [0] ModalitiesProfilerStarter._run_helper( @@ -133,7 +214,6 @@ def run_single_process( experiment_folder_path=experiment_root_path / experiment_id, global_rank=global_rank, world_size=world_size, - local_rank=local_rank, profiled_ranks=profiled_ranks, custom_component_registerables=custom_component_registerables, ) @@ -161,22 +241,35 @@ def _run_helper( experiment_folder_path: Path, profiled_ranks: list[int], global_rank: int, - local_rank: int, world_size: int, custom_component_registerables: list[CustomComponentRegisterable] | None = None, ): - # build profiler - profiler_activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - profile_context_manager = profile( + # build profilers + profiler_activities = [ProfilerActivity.CUDA] # ProfilerActivity.CPU, + kernel_profiler = profile( activities=profiler_activities, schedule=schedule(wait=num_wait_steps, warmup=num_warmup_steps, active=num_measurement_steps), - record_shapes=True, - profile_memory=True, - with_flops=True, - with_stack=True, - with_modules=True, + record_shapes=False, + profile_memory=False, + with_flops=False, + with_stack=False, + with_modules=False, + # record_shapes=True, + # profile_memory=True, + # with_flops=True, + # with_stack=True, + # with_modules=True, ) + SteppableMemoryProfiler( + memory_snapshot_path=experiment_folder_path / f"memory_snapshot_ranks_{world_size}_rank_{global_rank}.pkl", + num_wait_steps=num_wait_steps, + num_warmup_steps=num_warmup_steps, + num_active_steps=num_measurement_steps, + ) + + profile_context_manager = ProfilerListContext(profiler_cms=[kernel_profiler]) # , memory_profiler] + # register custom components and build components from config # workaround to avoid triggering synchronization of experiment id in single process experiment_id = experiment_folder_path.name if world_size == 1 else None @@ -199,15 +292,12 @@ def _run_helper( show_progress=(global_rank == profiled_ranks[0]), # only show progress on a single rank that is profiled ) trace_output_path = experiment_folder_path / f"profiler_trace_ranks_{world_size}_rank_{global_rank}.json" - memory_output_path = experiment_folder_path / f"profiler_memory_ranks_{world_size}_rank_{global_rank}.html" summary_output_path = experiment_folder_path / f"profiler_summary_ranks_{world_size}_rank_{global_rank}.txt" ModalitiesProfiler.export_profiling_results( - profiler_context_manager=profile_context_manager, + profiler_context_manager=kernel_profiler, trace_output_path=trace_output_path, - memory_output_path=memory_output_path, summary_output_path=summary_output_path, - local_rank=local_rank, global_rank=global_rank, profiled_ranks=profiled_ranks, ) @@ -218,7 +308,7 @@ class ModalitiesProfiler: def profile( steppable_component: SteppableComponentIF, num_total_steps: int, - profile_context_manager: torch.profiler.profile, + profile_context_manager: SteppableProfilerIF, show_progress: bool = False, ) -> None: """Profile a steppable component using the provided 
profiler context manager. @@ -243,10 +333,8 @@ def profile( def export_profiling_results( profiler_context_manager: torch.profiler.profile, trace_output_path: Path, - memory_output_path: Path, summary_output_path: Path, global_rank: int, - local_rank: int, profiled_ranks: list[int], ) -> None: """Export profiling results to specified output paths if the current rank is in profiled_ranks. @@ -263,8 +351,6 @@ def export_profiling_results( if global_rank in profiled_ranks: logger.info(f"Saving profiling results for rank {global_rank}...") profiler_context_manager.export_chrome_trace(trace_output_path.as_posix()) - device = local_rank if local_rank is not None else None - profiler_context_manager.export_memory_timeline(memory_output_path.as_posix(), device=device) table = profiler_context_manager.key_averages().table() with open(summary_output_path, "w", encoding="utf-8") as f: f.write(table) diff --git a/src/modalities/utils/profilers/steppable_component_configs.py b/src/modalities/utils/profilers/steppable_component_configs.py index 29248f0b6..03ae205f8 100644 --- a/src/modalities/utils/profilers/steppable_component_configs.py +++ b/src/modalities/utils/profilers/steppable_component_configs.py @@ -3,6 +3,7 @@ from modalities.config.pydantic_if_types import ( PydanticDatasetBatchGeneratorIFType, PydanticLossIFType, + PydanticOptimizerIFType, PydanticPytorchModuleType, ) @@ -11,3 +12,4 @@ class SteppableForwardPassConfig(BaseModel): model: PydanticPytorchModuleType dataset_batch_generator: PydanticDatasetBatchGeneratorIFType loss_fn: PydanticLossIFType | None = None + optimizer: PydanticOptimizerIFType | None = None diff --git a/src/modalities/utils/profilers/steppable_components.py b/src/modalities/utils/profilers/steppable_components.py index 526f79f73..40d025a32 100644 --- a/src/modalities/utils/profilers/steppable_components.py +++ b/src/modalities/utils/profilers/steppable_components.py @@ -15,7 +15,13 @@ class SteppableForwardPass(SteppableComponentIF): The component is used for profiling. """ - def __init__(self, model: nn.Module, dataset_batch_generator: DatasetBatchGeneratorIF, loss_fn: Loss | None = None): + def __init__( + self, + model: nn.Module, + dataset_batch_generator: DatasetBatchGeneratorIF, + loss_fn: Loss | None = None, + optimizer: torch.optim.Optimizer | None = None, + ): """Initializes the SteppableForwardPass component. 
Args: @@ -27,6 +33,7 @@ def __init__(self, model: nn.Module, dataset_batch_generator: DatasetBatchGenera self.loss_fn = loss_fn self.dataset_batch_generator = dataset_batch_generator self.device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + self.optimizer = optimizer def step( self, @@ -37,4 +44,8 @@ def step( predictions = self.model(batch.samples) result_batch = InferenceResultBatch(targets=batch.targets, predictions=predictions) if self.loss_fn is not None: - self.loss_fn(result_batch) + loss = self.loss_fn(result_batch) + loss.backward() + if self.optimizer is not None: + self.optimizer.step() + self.optimizer.zero_grad() From 9e661dc714e47d1238090397b69e85281eb0a454 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sat, 13 Dec 2025 15:54:06 +0100 Subject: [PATCH 05/14] feat: added fused and foreach options to Adam and AdamW optimizers --- src/modalities/config/config.py | 4 ++++ src/modalities/optimizers/optimizer_factory.py | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 5ae8c0822..607d902bf 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -150,6 +150,8 @@ class AdamOptimizerConfig(BaseModel): eps: float weight_decay: float weight_decay_groups_excluded: list[str] + foreach: bool | None = None + fused: bool | None = None class AdamWOptimizerConfig(BaseModel): @@ -159,6 +161,8 @@ class AdamWOptimizerConfig(BaseModel): eps: float weight_decay: float weight_decay_groups_excluded: list[str] + foreach: bool | None = None + fused: bool | None = None class DummyLRSchedulerConfig(BaseModel): diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py index 5679c99f4..9b92a1f8e 100644 --- a/src/modalities/optimizers/optimizer_factory.py +++ b/src/modalities/optimizers/optimizer_factory.py @@ -26,10 +26,12 @@ def get_adam( eps: float, weight_decay: float, weight_decay_groups_excluded: list[str], + foreach: bool | None, + fused: bool | None, wrapped_model: nn.Module, ) -> Optimizer: optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded) - optimizer = Adam(params=optimizer_groups, lr=lr, betas=betas, eps=eps) + optimizer = Adam(params=optimizer_groups, lr=lr, betas=betas, eps=eps, foreach=foreach, fused=fused) return optimizer @staticmethod @@ -39,10 +41,12 @@ def get_adam_w( eps: float, weight_decay: float, weight_decay_groups_excluded: list[str], + foreach: bool | None, + fused: bool | None, wrapped_model: nn.Module, ) -> Optimizer: optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded) - optimizer = AdamW(params=optimizer_groups, lr=lr, betas=betas, eps=eps) + optimizer = AdamW(params=optimizer_groups, lr=lr, betas=betas, eps=eps, foreach=foreach, fused=fused) return optimizer @staticmethod From 3518d00b84b5a522fdee157d8a9364cba991cad8 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:39:00 +0100 Subject: [PATCH 06/14] refactor: profilers are now components --- src/modalities/config/pydantic_if_types.py | 2 + src/modalities/registry/components.py | 32 +++ .../utils/profilers/modalities_profiler.py | 214 +----------------- .../utils/profilers/profiler_configs.py | 49 ++++ .../utils/profilers/profiler_factory.py | 100 ++++++++ src/modalities/utils/profilers/profilers.py | 207 +++++++++++++++++ 6 files 
changed, 398 insertions(+), 206 deletions(-) create mode 100644 src/modalities/utils/profilers/profiler_configs.py create mode 100644 src/modalities/utils/profilers/profiler_factory.py create mode 100644 src/modalities/utils/profilers/profilers.py diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index 92931dc42..81eb122dc 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -28,6 +28,7 @@ from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF from modalities.utils.mfu import MFUCalculatorABC from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF +from modalities.utils.profilers.profilers import SteppableProfilerIF from modalities.utils.profilers.steppable_components import SteppableComponentIF @@ -90,3 +91,4 @@ def __get_pydantic_core_schema__( PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)] PydanticSteppableComponentIFType = Annotated[SteppableComponentIF, PydanticThirdPartyTypeIF(SteppableComponentIF)] +PydanticSteppableProfilerIFType = Annotated[SteppableProfilerIF, PydanticThirdPartyTypeIF(SteppableProfilerIF)] diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index eaccbf2d3..351fd7b33 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -135,6 +135,14 @@ NumTokensFromPackedMemMapDatasetContinuousConfig, ) from modalities.utils.profilers.batch_generator import RandomDatasetBatchGenerator, RandomDatasetBatchGeneratorConfig +from modalities.utils.profilers.profiler_configs import ( + SteppableCombinedProfilerConfig, + SteppableKernelProfilerConfig, + SteppableMemoryProfilerConfig, + SteppableNoProfilerConfig, +) +from modalities.utils.profilers.profiler_factory import ProfilerFactory +from modalities.utils.profilers.profilers import SteppableCombinedProfiler, SteppableNoProfiler from modalities.utils.profilers.steppable_component_configs import SteppableForwardPassConfig from modalities.utils.profilers.steppable_components import SteppableForwardPass @@ -431,4 +439,28 @@ class ComponentEntity: SteppableForwardPass, SteppableForwardPassConfig, ), + ComponentEntity( + "steppable_profiler", + "kernel_tracing", + ProfilerFactory.create_steppable_kernel_profiler, + SteppableKernelProfilerConfig, + ), + ComponentEntity( + "steppable_profiler", + "memory_tracing", + ProfilerFactory.create_steppable_memory_profiler, + SteppableMemoryProfilerConfig, + ), + ComponentEntity( + "steppable_profiler", + "no_profiler", + SteppableNoProfiler, + SteppableNoProfilerConfig, + ), + ComponentEntity( + "steppable_profiler", + "combined", + SteppableCombinedProfiler, + SteppableCombinedProfilerConfig, + ), ] diff --git a/src/modalities/utils/profilers/modalities_profiler.py b/src/modalities/utils/profilers/modalities_profiler.py index 58e3a7fa9..c065323ba 100644 --- a/src/modalities/utils/profilers/modalities_profiler.py +++ b/src/modalities/utils/profilers/modalities_profiler.py @@ -1,26 +1,24 @@ -import pickle import shutil from dataclasses import dataclass from pathlib import Path import torch from pydantic import BaseModel -from torch.profiler import ProfilerActivity, profile, schedule from tqdm import trange from modalities.config.config import ProcessGroupBackendType -from modalities.config.pydantic_if_types import 
PydanticSteppableComponentIFType +from modalities.config.pydantic_if_types import PydanticSteppableComponentIFType, PydanticSteppableProfilerIFType from modalities.main import Main from modalities.running_env.cuda_env import CudaEnv from modalities.util import get_experiment_id_from_config, get_synced_experiment_id_of_run from modalities.utils.logger_utils import get_logger -from modalities.utils.profilers.steppable_components import SteppableComponentIF logger = get_logger("modalities_profiler") class InstantiationModel(BaseModel): steppable_component: PydanticSteppableComponentIFType + profiler: PydanticSteppableProfilerIFType @dataclass @@ -31,101 +29,13 @@ class CustomComponentRegisterable: custom_config: type -class SteppableProfilerIF: - def __enter__(self): - raise NotImplementedError - - def __exit__(self, exc_type, exc_value, traceback): - raise NotImplementedError - - def step(self): - raise NotImplementedError - - -class SteppableMemoryProfiler(SteppableProfilerIF): - MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000 - - def __init__(self, memory_snapshot_path: Path, num_wait_steps: int, num_warmup_steps: int, num_active_steps: int): - self._memory_snapshot_path = memory_snapshot_path - self._curr_step = None - self._num_wait_steps = num_wait_steps - self._num_warmup_steps = num_warmup_steps - self._num_active_steps = num_active_steps - - def __enter__(self): - self._curr_step = 0 - # start recording memory history if there is no wait / warmup steps - if self._curr_step == self._num_wait_steps + self._num_warmup_steps and self._num_active_steps > 0: - torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES) - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self._curr_step is None: - raise RuntimeError("SteppableMemoryProfilerContext exited without being entered") - if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps: - # if we exit before finishing all steps, dump the memory snapshot - raise RuntimeError("SteppableMemoryProfilerContext exited before finishing all steps") - return - - def step(self): - if self._curr_step is None: - raise RuntimeError("SteppableMemoryProfilerContext.step() called outside of context manager") - self._curr_step += 1 - if self._curr_step < self._num_wait_steps + self._num_warmup_steps: - return - elif self._curr_step == self._num_wait_steps + self._num_warmup_steps: - torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES) - elif ( - self._curr_step == self._num_wait_steps + self._num_warmup_steps + self._num_active_steps - and self._num_active_steps > 0 - ): - with open(self._memory_snapshot_path, "wb") as output: - pickle.dump(torch.cuda.memory._snapshot(), output) - - -class ProfilerListContext(SteppableProfilerIF): - def __init__(self, profiler_cms: list[SteppableProfilerIF]): - self.profiler_cms = profiler_cms - self._entered = None - - def __enter__(self): - if self._entered is not None: - raise RuntimeError("ProfilerListContext entered multiple times without exiting") - self._entered = [] - for profiler_cm in self.profiler_cms: - return_val = profiler_cm.__enter__() - if return_val is not None: - self._entered.append(return_val) - else: - self._entered.append(profiler_cm) - - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self._entered is None: - raise RuntimeError("ProfilerListContext exited without being entered") - for profiler_cm in self._entered: - 
profiler_cm.__exit__(exc_type, exc_value, traceback) - self._entered = None - - def step(self): - if self._entered is None: - raise RuntimeError("ProfilerListContext.step() called outside of context manager") - for profiler_cm in self._entered: - profiler_cm.step() - - class ModalitiesProfilerStarter: """Starter class to run profiling either in single process or distributed mode.""" @staticmethod def run_distributed( config_file_path: Path, - num_measurement_steps: int, - num_wait_steps: int, - num_warmup_steps: int, experiment_root_path: Path, - profiled_ranks: list[int], experiment_id: str | None = None, custom_component_registerables: list[CustomComponentRegisterable] | None = None, ): @@ -134,11 +44,7 @@ def run_distributed( Args: config_file_path (Path): Path to the configuration file. - num_measurement_steps (int): Number of measurement steps for profiling. - num_wait_steps (int): Number of wait steps before profiling starts. - num_warmup_steps (int): Number of warmup steps before measurement starts. experiment_root_path (Path): Root path to store experiment results. - profiled_ranks (list[int]): List of ranks to profile. experiment_id (str, optional): Experiment ID. If None, it will be generated. Defaults to None. custom_component_registerables (list[CustomComponentRegisterable], optional): List of custom components to register. Defaults to None. @@ -158,22 +64,15 @@ def run_distributed( ModalitiesProfilerStarter._run_helper( config_file_path=config_file_path, - num_measurement_steps=num_measurement_steps, - num_wait_steps=num_wait_steps, - num_warmup_steps=num_warmup_steps, experiment_folder_path=experiment_root_path / experiment_id, global_rank=global_rank, world_size=world_size, - profiled_ranks=profiled_ranks, custom_component_registerables=custom_component_registerables, ) @staticmethod def run_single_process( config_file_path: Path, - num_measurement_steps: int, - num_wait_steps: int, - num_warmup_steps: int, experiment_root_path: Path, experiment_id: str | None = None, custom_component_registerables: list[CustomComponentRegisterable] | None = None, @@ -186,9 +85,6 @@ def run_single_process( Args: config_file_path (Path): Path to the configuration file. - num_measurement_steps (int): Number of measurement steps for profiling. - num_wait_steps (int): Number of wait steps before profiling starts. - num_warmup_steps (int): Number of warmup steps before measurement starts. experiment_root_path (Path): Root path to store experiment results. experiment_id (str, optional): Experiment ID. If None, it will be generated. 
custom_component_registerables (list[CustomComponentRegisterable], optional): List of custom @@ -204,17 +100,12 @@ def run_single_process( global_rank = 0 world_size = 1 - profiled_ranks = [0] ModalitiesProfilerStarter._run_helper( config_file_path=config_file_path, - num_measurement_steps=num_measurement_steps, - num_wait_steps=num_wait_steps, - num_warmup_steps=num_warmup_steps, experiment_folder_path=experiment_root_path / experiment_id, global_rank=global_rank, world_size=world_size, - profiled_ranks=profiled_ranks, custom_component_registerables=custom_component_registerables, ) @@ -235,41 +126,11 @@ def _copy_config_to_experiment_folder( @staticmethod def _run_helper( config_file_path: Path, - num_measurement_steps: int, - num_wait_steps: int, - num_warmup_steps: int, experiment_folder_path: Path, - profiled_ranks: list[int], global_rank: int, world_size: int, custom_component_registerables: list[CustomComponentRegisterable] | None = None, ): - # build profilers - profiler_activities = [ProfilerActivity.CUDA] # ProfilerActivity.CPU, - kernel_profiler = profile( - activities=profiler_activities, - schedule=schedule(wait=num_wait_steps, warmup=num_warmup_steps, active=num_measurement_steps), - record_shapes=False, - profile_memory=False, - with_flops=False, - with_stack=False, - with_modules=False, - # record_shapes=True, - # profile_memory=True, - # with_flops=True, - # with_stack=True, - # with_modules=True, - ) - - SteppableMemoryProfiler( - memory_snapshot_path=experiment_folder_path / f"memory_snapshot_ranks_{world_size}_rank_{global_rank}.pkl", - num_wait_steps=num_wait_steps, - num_warmup_steps=num_warmup_steps, - num_active_steps=num_measurement_steps, - ) - - profile_context_manager = ProfilerListContext(profiler_cms=[kernel_profiler]) # , memory_profiler] - # register custom components and build components from config # workaround to avoid triggering synchronization of experiment id in single process experiment_id = experiment_folder_path.name if world_size == 1 else None @@ -283,74 +144,15 @@ def _run_helper( custom_config=registerable.custom_config, ) components: InstantiationModel = main_obj.build_components(components_model_type=InstantiationModel) + steppable_component = components.steppable_component + profiler_cm = components.profiler - # run profiling - ModalitiesProfiler.profile( - steppable_component=components.steppable_component, - num_total_steps=num_measurement_steps + num_wait_steps + num_warmup_steps, - profile_context_manager=profile_context_manager, - show_progress=(global_rank == profiled_ranks[0]), # only show progress on a single rank that is profiled - ) - trace_output_path = experiment_folder_path / f"profiler_trace_ranks_{world_size}_rank_{global_rank}.json" - summary_output_path = experiment_folder_path / f"profiler_summary_ranks_{world_size}_rank_{global_rank}.txt" - - ModalitiesProfiler.export_profiling_results( - profiler_context_manager=kernel_profiler, - trace_output_path=trace_output_path, - summary_output_path=summary_output_path, - global_rank=global_rank, - profiled_ranks=profiled_ranks, - ) - - -class ModalitiesProfiler: - @staticmethod - def profile( - steppable_component: SteppableComponentIF, - num_total_steps: int, - profile_context_manager: SteppableProfilerIF, - show_progress: bool = False, - ) -> None: - """Profile a steppable component using the provided profiler context manager. - - Args: - steppable_component (SteppableComponentIF): The steppable component to profile. - num_total_steps (int): Total number of steps to run. 
- profile_context_manager (profile): The profiler context manager. - show_progress (bool): Whether to show a progress bar. Defaults to False. - """ - if show_progress: - step_iterator = trange(num_total_steps, desc="Profiling steps") + if global_rank == 0: + step_iterator = trange(len(profiler_cm), desc="Profiling steps") else: - step_iterator = range(num_total_steps) + step_iterator = range(len(profiler_cm)) - with profile_context_manager as profiler: + with profiler_cm as profiler: for _ in step_iterator: steppable_component.step() profiler.step() - - @staticmethod - def export_profiling_results( - profiler_context_manager: torch.profiler.profile, - trace_output_path: Path, - summary_output_path: Path, - global_rank: int, - profiled_ranks: list[int], - ) -> None: - """Export profiling results to specified output paths if the current rank is in profiled_ranks. - - Args: - profiler_context_manager (profile): The profiler context manager. - trace_output_path (Path): Path to save the Chrome trace. - memory_output_path (Path): Path to save the memory timeline. - summary_output_path (Path): Path to save the summary table. - global_rank (int): The global rank of the current process. - local_rank (int): The local rank of the current process. - profiled_ranks (list[int]): List of ranks to profile. - """ - if global_rank in profiled_ranks: - logger.info(f"Saving profiling results for rank {global_rank}...") - profiler_context_manager.export_chrome_trace(trace_output_path.as_posix()) - table = profiler_context_manager.key_averages().table() - with open(summary_output_path, "w", encoding="utf-8") as f: - f.write(table) diff --git a/src/modalities/utils/profilers/profiler_configs.py b/src/modalities/utils/profilers/profiler_configs.py new file mode 100644 index 000000000..2a056a81c --- /dev/null +++ b/src/modalities/utils/profilers/profiler_configs.py @@ -0,0 +1,49 @@ +from pathlib import Path + +from pydantic import BaseModel + +from modalities.config.lookup_enum import LookupEnum +from modalities.config.pydantic_if_types import PydanticSteppableProfilerIFType + + +class ModalitiesProfilerActivity(LookupEnum): + CPU = "CPU" + CUDA = "CUDA" + + +class SteppableKernelProfilerConfig(BaseModel): + """Settings for the kernel profiler.""" + + num_wait_steps: int + num_warmup_steps: int + num_active_steps: int + profiler_activities: list[ModalitiesProfilerActivity] + record_shapes: bool + profile_memory: bool + with_flops: bool + with_stack: bool + with_modules: bool + output_folder_path: Path + tracked_ranks: list[int] | None = None + + +class SteppableMemoryProfilerConfig(BaseModel): + """Settings for the memory profiler.""" + + memory_snapshot_folder_path: Path + num_wait_steps: int + num_warmup_steps: int + num_active_steps: int + tracked_ranks: list[int] | None = None + + +class SteppableNoProfilerConfig(BaseModel): + """Settings for no profiler.""" + + pass + + +class SteppableCombinedProfilerConfig(BaseModel): + """Settings for combined profilers.""" + + profilers: list[PydanticSteppableProfilerIFType] diff --git a/src/modalities/utils/profilers/profiler_factory.py b/src/modalities/utils/profilers/profiler_factory.py new file mode 100644 index 000000000..0dfe335fc --- /dev/null +++ b/src/modalities/utils/profilers/profiler_factory.py @@ -0,0 +1,100 @@ +from pathlib import Path + +import torch + +from modalities.utils.profilers.profiler_configs import ModalitiesProfilerActivity +from modalities.utils.profilers.profilers import ( + SteppableKernelProfiler, + SteppableMemoryProfiler, + 
SteppableNoProfiler, + SteppableProfilerIF, +) + + +class ProfilerFactory: + """Factory class to create different types of profilers based on the provided settings.""" + + @staticmethod + def create_steppable_kernel_profiler( + num_wait_steps: int, + num_warmup_steps: int, + num_active_steps: int, + profiler_activities: list[ModalitiesProfilerActivity], + record_shapes: bool, + profile_memory: bool, + with_flops: bool, + with_stack: bool, + with_modules: bool, + output_folder_path: Path, + tracked_ranks: list[int] | None = None, + ) -> SteppableProfilerIF: + """Creates a steppable kernel profiler based on the provided settings.""" + torch.profiler.ProfilerActivity + if tracked_ranks is None: + tracked_ranks = [] + global_rank, world_size = ProfilerFactory._get_global_rank_and_world_size() + + profiler_activities_converted = [] + for activity in profiler_activities: + if activity == ModalitiesProfilerActivity.CPU: + profiler_activities_converted.append(torch.profiler.ProfilerActivity.CPU) + elif activity == ModalitiesProfilerActivity.CUDA: + profiler_activities_converted.append(torch.profiler.ProfilerActivity.CUDA) + + trace_output_path = output_folder_path / f"profiler_trace_ranks_{world_size}_rank_{global_rank}.json" + summary_output_path = output_folder_path / f"profiler_summary_ranks_{world_size}_rank_{global_rank}.txt" + + profiler = SteppableKernelProfiler( + num_wait_steps=num_wait_steps, + num_warmup_steps=num_warmup_steps, + num_active_steps=num_active_steps, + profiler_activities=profiler_activities_converted, + record_shapes=record_shapes, + profile_memory=profile_memory, + with_flops=with_flops, + with_stack=with_stack, + with_modules=with_modules, + trace_output_path=trace_output_path, + summary_output_path=summary_output_path, + ) + + if global_rank not in tracked_ranks: + num_steps = len(profiler) + return SteppableNoProfiler(num_steps=num_steps) + else: + return profiler + + @staticmethod + def create_steppable_memory_profiler( + memory_snapshot_folder_path: Path, + num_wait_steps: int, + num_warmup_steps: int, + num_active_steps: int, + tracked_ranks: list[int] | None = None, + ) -> SteppableProfilerIF: + """Creates a steppable memory profiler based on the provided settings.""" + if tracked_ranks is None: + tracked_ranks = [] + + global_rank, world_size = ProfilerFactory._get_global_rank_and_world_size() + profiler = SteppableMemoryProfiler( + memory_snapshot_path=memory_snapshot_folder_path + / f"memory_snapshot_ranks_{world_size}_rank_{global_rank}.pkl", + num_wait_steps=num_wait_steps, + num_warmup_steps=num_warmup_steps, + num_active_steps=num_active_steps, + ) + if global_rank not in tracked_ranks: + num_steps = len(profiler) + return SteppableNoProfiler(num_steps=num_steps) + else: + return profiler + + @staticmethod + def _get_global_rank_and_world_size() -> tuple[int, int]: + global_rank = 0 + world_size = 1 + if torch.distributed.is_initialized(): + global_rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return global_rank, world_size diff --git a/src/modalities/utils/profilers/profilers.py b/src/modalities/utils/profilers/profilers.py new file mode 100644 index 000000000..40ffe191e --- /dev/null +++ b/src/modalities/utils/profilers/profilers.py @@ -0,0 +1,207 @@ +import pickle +from pathlib import Path + +import torch +from torch.profiler import profile, schedule + +from modalities.utils.logger_utils import get_logger + +logger = get_logger("modalities_profiler") + + +class SteppableProfilerIF: + def __enter__(self): + raise 
NotImplementedError + + def __exit__(self, exc_type, exc_value, traceback): + raise NotImplementedError + + def step(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class SteppableCombinedProfiler(SteppableProfilerIF): + def __init__(self, profilers: list[SteppableProfilerIF]): + self._profilers = profilers + self._entered = None + + def __enter__(self): + if self._entered is not None: + raise RuntimeError("ProfilerListContext entered multiple times without exiting") + self._entered = [] + for profiler_cm in self._profilers: + return_val = profiler_cm.__enter__() + if return_val is not None: + self._entered.append(return_val) + else: + self._entered.append(profiler_cm) + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self._entered is None: + raise RuntimeError("ProfilerListContext exited without being entered") + for profiler_cm in self._entered: + profiler_cm.__exit__(exc_type, exc_value, traceback) + self._entered = None + + def __len__(self): + max_len = max([len(p) for p in self._profilers]) + min_len = min([len(p) for p in self._profilers if not isinstance(p, SteppableNoProfiler)]) + if max_len != min_len: + logger.warning( + "SteppableCombinedProfiler has profilers of different step lengths." + f" Max steps: {max_len}, Min steps: {min_len}." + " The combined profiler will run for the maximum steps, and some profilers may be inactive or fail." + ) + return max_len + + def step(self): + if self._entered is None: + raise RuntimeError("ProfilerListContext.step() called outside of context manager") + for profiler_cm in self._entered: + profiler_cm.step() + + +class SteppableNoProfiler(SteppableProfilerIF): + def __init__(self, num_steps: int) -> None: + super().__init__() + self._num_steps = num_steps + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return + + def __len__(self): + return self._num_steps + + def step(self): + return + + +class SteppableMemoryProfiler(SteppableProfilerIF): + MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000 + + def __init__(self, memory_snapshot_path: Path, num_wait_steps: int, num_warmup_steps: int, num_active_steps: int): + self._memory_snapshot_path = memory_snapshot_path + self._curr_step = None + self._num_wait_steps = num_wait_steps + self._num_warmup_steps = num_warmup_steps + self._num_active_steps = num_active_steps + + def __enter__(self): + self._curr_step = 0 + # start recording memory history if there are no wait / warmup steps + if self._curr_step == self._num_wait_steps + self._num_warmup_steps and self._num_active_steps > 0: + torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES) + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self._curr_step is None: + raise RuntimeError("SteppableMemoryProfilerContext exited without being entered") + if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps: + # if we exit before finishing all steps, the memory snapshot was never dumped + raise RuntimeError("SteppableMemoryProfilerContext exited before finishing all steps") + return + + def __len__(self): + return self._num_wait_steps + self._num_warmup_steps + self._num_active_steps + + def step(self): + if self._curr_step is None: + raise RuntimeError("SteppableMemoryProfilerContext.step() called outside of context manager") + self._curr_step += 1 + if self._curr_step < self._num_wait_steps + self._num_warmup_steps: + return + elif self._curr_step ==
self._num_wait_steps + self._num_warmup_steps: + torch.cuda.memory._record_memory_history(max_entries=SteppableMemoryProfiler.MEMORY_SNAPSHOT_MAX_ENTRIES) + elif ( + self._curr_step == self._num_wait_steps + self._num_warmup_steps + self._num_active_steps + and self._num_active_steps > 0 + ): + with open(self._memory_snapshot_path, "wb") as output: + pickle.dump(torch.cuda.memory._snapshot(), output) + + +class SteppableKernelProfiler(SteppableProfilerIF): + def __init__( + self, + num_wait_steps: int, + num_warmup_steps: int, + num_active_steps: int, + profiler_activities: list[torch.profiler.ProfilerActivity], + record_shapes: bool, + profile_memory: bool, + with_flops: bool, + with_stack: bool, + with_modules: bool, + trace_output_path: Path, + summary_output_path: Path, + ) -> None: # TODO specify Callable type + super().__init__() + self._num_wait_steps = num_wait_steps + self._num_warmup_steps = num_warmup_steps + self._num_active_steps = num_active_steps + self._profiler_activities = profiler_activities + self._record_shapes = record_shapes + self._profile_memory = profile_memory + self._with_flops = with_flops + self._with_stack = with_stack + self._with_modules = with_modules + self._trace_output_path = trace_output_path + self._summary_output_path = summary_output_path + self._kernel_profiler = None + + def __enter__(self): + if self._kernel_profiler is not None: + raise RuntimeError("Context entered multiple times without exiting") + self._curr_step = 0 + self._kernel_profiler = profile( + activities=self._profiler_activities, + schedule=schedule(wait=self._num_wait_steps, warmup=self._num_warmup_steps, active=self._num_active_steps), + record_shapes=self._record_shapes, + profile_memory=self._profile_memory, + with_flops=self._with_flops, + with_stack=self._with_stack, + with_modules=self._with_modules, + ) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._export_profiling_results() + self._kernel_profiler = None + if self._curr_step is None: + raise RuntimeError("SteppableKernelProfiler exited without being entered") + if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps: + # if we exit before finishing all steps, dump the memory snapshot + raise RuntimeError("SteppableKernelProfiler exited before finishing all steps") + return + + def __len__(self): + return self._num_wait_steps + self._num_warmup_steps + self._num_active_steps + + def step(self): + if self._curr_step is None: + raise RuntimeError("SteppableKernelProfiler.step() called outside of context manager") + if self._kernel_profiler is None: + raise RuntimeError("SteppableKernelProfiler.step() called when profiler is not initialized") + self._curr_step += 1 + self._kernel_profiler.step() + + def _export_profiling_results(self) -> None: + # Export profiling results to specified output paths if the current rank is in profiled_ranks. 
+ if self._kernel_profiler is None: + raise RuntimeError( + "SteppableKernelProfiler._export_profiling_results() called when profiler is not initialized" + ) + logger.info("Saving profiling results...") + self._kernel_profiler.export_chrome_trace(self._trace_output_path.as_posix()) + table = self._kernel_profiler.key_averages().table() + self._summary_output_path.parent.mkdir(parents=True, exist_ok=True) + with open(self._summary_output_path, "w", encoding="utf-8") as f: + f.write(table) From 02e8fddf7b18dfdd75320bbadd2c559291df7488 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:40:17 +0100 Subject: [PATCH 07/14] feat: logger outputs now rank info --- src/modalities/utils/logger_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/modalities/utils/logger_utils.py b/src/modalities/utils/logger_utils.py index 47f9c7797..095e998d1 100644 --- a/src/modalities/utils/logger_utils.py +++ b/src/modalities/utils/logger_utils.py @@ -1,11 +1,18 @@ import logging +import torch + def get_logger(name: str = "main") -> logging.Logger: + rank_info = "" + + if torch.distributed.is_initialized(): + rank_info = f"[RANK {torch.distributed.get_rank()}] " + logger = logging.getLogger(name) if not logger.handlers: logger.setLevel(logging.DEBUG) handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter("%(name)s - %(levelname)s - %(message)s")) + handler.setFormatter(logging.Formatter(f"{rank_info}%(name)s - %(levelname)s - %(message)s")) logger.addHandler(handler) return logger From 37b25d835ce5bb70f77cb07b5d1b18a106a739ae Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:40:55 +0100 Subject: [PATCH 08/14] refactor: step information in profiling now part of the config instead of CMD args --- src/modalities/__main__.py | 38 ------------------- .../run_distributed_model_profiling.py | 9 ----- 2 files changed, 47 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 891d611e0..48de07f57 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -705,52 +705,14 @@ def profile(): required=True, help="Path to the experiment output directory.", ) -@click.option( - "--num_wait_steps", - type=int, - default=1, - show_default=True, - help="Number of wait steps to skip in profiling.", -) -@click.option( - "--num_warmup_steps", - type=int, - default=1, - show_default=True, - help="Number of warmup steps to skip in profiling. Already recording but dropping the data.", -) -@click.option( - "--num_measurement_steps", - type=int, - default=3, - show_default=True, - help="Number of steps to measure during profiling.", -) -@click.option( - "--profiled_ranks", - type=str, - default="0", - help="Comma-separated list of profiled ranks (must not have spaces), e.g. 
--profiled_ranks '2,4,8'", -) def CMD_entry_point_run_train_step_profiler( config_file_path: Path, experiment_root_path: Path, - num_wait_steps: int, - num_warmup_steps: int, - num_measurement_steps: int, - profiled_ranks: str, ): """Run train step profiler and write result to JSON if RANK=0.""" - profiled_ranks_list = [int(i) for i in profiled_ranks.split(",")] if profiled_ranks != "" else [0] - logger.info(f"Running distributed profiling on ranks {profiled_ranks_list}") - ModalitiesProfilerStarter.run_distributed( config_file_path=config_file_path, - num_measurement_steps=num_measurement_steps, - num_wait_steps=num_wait_steps, - num_warmup_steps=num_warmup_steps, experiment_root_path=experiment_root_path, - profiled_ranks=profiled_ranks_list, ) diff --git a/tutorials/profiling/scripts/distributed/run_distributed_model_profiling.py b/tutorials/profiling/scripts/distributed/run_distributed_model_profiling.py index 619b22c92..4ee2d7966 100644 --- a/tutorials/profiling/scripts/distributed/run_distributed_model_profiling.py +++ b/tutorials/profiling/scripts/distributed/run_distributed_model_profiling.py @@ -7,16 +7,7 @@ config_path = cwd / Path("../../configs/distributed_8B_model_profiling.yaml") experiment_root_path = cwd / Path("../../experiments/") - num_measurement_steps = 3 - num_wait_steps = 20 - num_warmup_steps = 20 - profiled_ranks = [0, 1] - ModalitiesProfilerStarter.run_distributed( config_file_path=config_path, - num_measurement_steps=num_measurement_steps, - num_wait_steps=num_wait_steps, - num_warmup_steps=num_warmup_steps, experiment_root_path=experiment_root_path, - profiled_ranks=profiled_ranks, ) From 52924ea4168f3a4252bc20d5c8ee32cb85457100 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 19 Dec 2025 12:43:29 +0100 Subject: [PATCH 09/14] refactor: added new profiling setup to the profiling tutorial's config --- .../distributed_8B_model_profiling.yaml | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tutorials/profiling/configs/distributed_8B_model_profiling.yaml b/tutorials/profiling/configs/distributed_8B_model_profiling.yaml index b8e996387..88ef2763d 100644 --- a/tutorials/profiling/configs/distributed_8B_model_profiling.yaml +++ b/tutorials/profiling/configs/distributed_8B_model_profiling.yaml @@ -6,7 +6,26 @@ settings: benchmark: sequence_length: 4096 vocab_size: 50304 - batch_size: 2 + batch_size: 1 + paths: + experiment_root_path: ${modalities_env:config_folder_path} + +profiler: + component_key: steppable_profiler + variant_key: kernel_tracing + config: + num_wait_steps: 5 + num_warmup_steps: 5 + num_active_steps: 3 + profiler_activities: [CPU, CUDA] + record_shapes: false + profile_memory: false + with_stack: false + with_flops: false + with_modules: false + tracked_ranks: [0, 1] + output_folder_path: ${settings.paths.experiment_root_path}/kernel_traces + steppable_component: component_key: steppable_component From 3cfa30526d6dc5292006a2a3e5bd128ccb637ff7 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sun, 21 Dec 2025 17:14:16 +0100 Subject: [PATCH 10/14] refactor: experiments_root_path now passed in from outside --- src/modalities/__main__.py | 22 +++++++++++++++------- src/modalities/main.py | 10 ++++++++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 48de07f57..9bf60324f 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -54,10 
+54,10 @@ def main() -> None: help="Path to the YAML training config file.", ) @click.option( - "--test_comm", - is_flag=True, - default=False, - help="If set, run a communication test before training.", + "--experiments_root_path", + type=click_pathlib.Path(exists=True), + required=True, + help="Path to the root directory where experiment folders will be created.", ) @click.option( "--experiment_id", @@ -71,20 +71,28 @@ def main() -> None: default=None, help="Optional path to a folder where error logs will be written.", ) +@click.option( + "--test_comm", + is_flag=True, + default=False, + help="If set, run a communication test before training.", +) def CMD_entry_point_run_modalities( config_file_path: Path, - test_comm: bool = False, + experiments_root_path: Path, experiment_id: Optional[str] = None, error_log_folder: Optional[Path] = None, + test_comm: bool = False, ): """Entrypoint to run the model training. Args: config_file_path (Path): Path to the YAML training config file. - test_comm (bool): If set, run a communication test before training. + experiments_root_path (Path): Path to the root directory where experiment folders will be created. experiment_id (Optional[str]): Optional experiment ID to use for this run. If not provided it will be generated. Default is None. error_log_folder (Optional[Path]): Optional path to a folder where error logs will be written. + test_comm (bool): If set, run a communication test before training. """ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str: @@ -104,7 +112,7 @@ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str: run_communication_test() print_rank_0("Communication test succeeded.") - main_obj = Main(config_file_path, experiment_id=experiment_id) + main_obj = Main(config_file_path, experiments_root_path=experiments_root_path, experiment_id=experiment_id) components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel) main_obj.run(components) except Exception as e: diff --git a/src/modalities/main.py b/src/modalities/main.py index 1fc46fdab..ffb673df6 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -39,14 +39,19 @@ class Main: def __init__( self, config_path: Path, + experiments_root_path: Path, additional_resolver_funs: Optional[dict[str, Callable]] = None, experiment_id: Optional[str] = None, ) -> None: + self.experiments_root_path = experiments_root_path if experiment_id is None: experiment_id = get_synced_experiment_id_of_run(config_path) self.config_dict = load_app_config_dict( - config_file_path=config_path, experiment_id=experiment_id, additional_resolver_funs=additional_resolver_funs + config_file_path=config_path, + experiments_root_path=experiments_root_path, + experiment_id=experiment_id, + additional_resolver_funs=additional_resolver_funs, ) self.config_path = config_path @@ -109,7 +114,7 @@ def run(self, components: TrainingComponentsInstantiationModel): # In this case, we only allow the config file to be present in the experiment folder. # NOTE: For the future, these constraints might be relaxed, as some components might have to # store meta data in the experiment folder at instantiation time. 
- experiment_path = components.settings.paths.checkpoint_saving_path / components.settings.experiment_id + experiment_path = components.settings.paths.experiments_root_path / components.settings.experiment_id expected_config_file_path = experiment_path / self.config_path.name if experiment_path.is_dir(): present_files = list(experiment_path.iterdir()) @@ -180,6 +185,7 @@ def run(self, components: TrainingComponentsInstantiationModel): global_num_tokens_per_train_step=global_num_tokens_per_train_step, device_mesh=components.device_mesh, mfu_calculator=components.mfu_calculator, + profiler=components.profiler, ) # Evaluator From 361ddc5619293f3321092c8bed805b02d6ad002e Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Sun, 21 Dec 2025 17:16:10 +0100 Subject: [PATCH 11/14] feat: profiling now available also in training loop --- src/modalities/config/config.py | 2 + src/modalities/config/instantiation_models.py | 13 +- src/modalities/trainer.py | 238 +++++++++--------- .../utils/profilers/profiler_configs.py | 1 - .../utils/profilers/profiler_factory.py | 2 - src/modalities/utils/profilers/profilers.py | 30 ++- 6 files changed, 153 insertions(+), 133 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 607d902bf..b2e9ff50d 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -505,6 +505,7 @@ class ParallelDegreeConfig(BaseModel): def load_app_config_dict( config_file_path: Path, + experiments_root_path: Path, experiment_id: Optional[str] = None, additional_resolver_funs: Optional[dict[str, Resolver]] = None, ) -> dict[str, YAMLValue]: @@ -537,6 +538,7 @@ def node_env_resolver_fun(var_name: str) -> int | None: modalities_env_kwargs: dict[str, Any] = { "config_file_path": config_file_path, "config_folder_path": config_file_path.parent, + "experiments_root_path": experiments_root_path, } if experiment_id is not None: modalities_env_kwargs["experiment_id"] = experiment_id diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 570c03178..7df862e1c 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -17,12 +17,14 @@ PydanticPipelineType, PydanticPytorchDeviceType, PydanticPytorchModuleType, + PydanticSteppableProfilerIFType, PydanticTextInferenceComponentType, PydanticTokenizerIFType, ) from modalities.config.utils import parse_torch_device from modalities.dataloader.dataset import Dataset from modalities.util import warn_rank_0 +from modalities.utils.profilers.profilers import SteppableNoProfiler class CudaEnvSettings(BaseModel): @@ -66,7 +68,7 @@ class TrainingProgress(BaseModel): class TrainingComponentsInstantiationModel(BaseModel): class Settings(BaseModel): class Paths(BaseModel): - checkpoint_saving_path: Path # Explicitly defined field + experiments_root_path: Path # Explicitly defined field class Config: extra = "allow" @@ -180,13 +182,14 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel evaluation_subscriber: PydanticMessageSubscriberIFType checkpoint_saving: PydanticCheckpointSavingIFType gradient_clipper: PydanticGradientClipperIFType - mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None - scheduled_pipeline: Optional[PydanticPipelineType] = None - device_mesh: Optional[PydanticDeviceMeshIFType] = None + profiler: PydanticSteppableProfilerIFType = SteppableNoProfiler() + mfu_calculator: 
PydanticMFUCalculatorABCType | None = None + scheduled_pipeline: PydanticPipelineType | None = None + device_mesh: PydanticDeviceMeshIFType | None = None model_raw: PydanticPytorchModuleType @model_validator(mode="after") - def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel.Settings": + def _check_token_amount_in_dataset(self) -> "TrainingComponentsInstantiationModel": if ( len(self.train_dataset) * self.settings.step_profile.sequence_length < self.settings.training_target.num_target_tokens diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index ac00b41b8..cb1250a2c 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -23,6 +23,7 @@ from modalities.training.training_progress import TrainingProgress from modalities.util import TimeRecorder, print_rank_0 from modalities.utils.mfu import MFUCalculatorABC +from modalities.utils.profilers.profilers import SteppableProfilerIF class ThroughputAggregationKeys(Enum): @@ -44,7 +45,8 @@ def __init__( num_target_steps: int, num_target_tokens: int, gradient_clipper: GradientClipperIF, - mfu_calculator: Optional[MFUCalculatorABC] = None, + profiler: SteppableProfilerIF, + mfu_calculator: MFUCalculatorABC | None = None, ) -> None: """ Initializes the Trainer object. @@ -62,6 +64,7 @@ def __init__( num_target_steps (int): Number of target steps. num_target_tokens (int): Number of target tokens. gradient_clipper (GradientClipperIF): Gradient clipper. + profiler (SteppableProfilerIF): Profiler to profile the training loop. mfu_calculator (Optional[MFUCalculatorABC]): MFU calculator. Returns: @@ -85,6 +88,7 @@ def __init__( self.num_target_tokens = num_target_tokens self.global_num_seen_tokens = global_num_seen_tokens self.gradient_clipper = gradient_clipper + self.profiler = profiler self.mfu_calculator = mfu_calculator @staticmethod @@ -230,124 +234,128 @@ def train( num_steps_todo = self.num_target_steps - self.num_seen_train_steps num_batches_todo = num_steps_todo * self.gradient_acc_steps # Because we might resume training, we add the starting batch id of the data loader - for _, (micro_batch_id, batch) in zip(range(num_batches_todo), enumerate(train_loader)): - # Train single batch - ( - step_performed, - num_train_steps_done, - batch_loss, - gradient_norm_score, - ) = self._train_batch( - batch=batch, - model=model, - optimizer=optimizer, - scheduler=lr_scheduler, - loss_fun=loss_fun, - micro_batch_id=micro_batch_id, - scheduled_pipeline=scheduled_pipeline, - ) - training_progress.num_seen_steps_current_run = num_train_steps_done - training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done - - # The batch_loss might be None if we use pipeline parallelism and are not the last stage. 
- if batch_loss is not None: - # Save the batch loss - cumulated_losses[0] += batch_loss.item() - # This works, because we always drop the last batch in case it has less samples than the batch size - cumulated_losses[-1] += 1 # number of local batches - - # gradient norm is already synced across all ranks - if gradient_norm_score is not None: - gradient_norm_scores.append(gradient_norm_score.item()) - - local_num_seen_samples += torch.tensor(len(batch)).to(device) - - self._publish_progress( - progress_publisher=self.progress_publisher, - num_train_steps_done=training_progress.num_seen_steps_total, - dataloader_tag=train_loader.dataloader_tag, - ) - # Check if model performance should be logged - if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed: - forward_backward_time_recorder.stop() - forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device) - forward_backward_time_recorder.reset() - forward_backward_time_recorder.start() - - global_num_seen_samples = local_num_seen_samples * self.dp_degree - local_num_seen_samples = 0 - global_num_samples_per_second = global_num_seen_samples / forward_backward_time - - # TODO: insert reducer from outside so Trainer is independent of FSDP - # add the loss and gradient norm for the LAST batch - cumulated_losses[1] = batch_loss.item() if batch_loss is not None else 0.0 - - reduced_losses = Reducer.reduce( - tensor=cumulated_losses, - operation=dist.ReduceOp.SUM, - # 1.) summed batch loss / (num batches * (world size / dp_degree)) - # 2.) last batch loss / (world size / pp_degree) - post_processing_fun=lambda t: torch.stack( - [t[0] / t[-1], t[1] / dist.get_world_size() * self.pp_degree] - ), + with self.profiler as profiler_cm: + for _, (micro_batch_id, batch) in zip(range(num_batches_todo), enumerate(train_loader)): + # Train single batch + ( + step_performed, + num_train_steps_done, + batch_loss, + gradient_norm_score, + ) = self._train_batch( + batch=batch, + model=model, + optimizer=optimizer, + scheduler=lr_scheduler, + loss_fun=loss_fun, + micro_batch_id=micro_batch_id, + scheduled_pipeline=scheduled_pipeline, ) - - train_loss_avg, train_loss_last_batch = ( - reduced_losses[0], - reduced_losses[1], + training_progress.num_seen_steps_current_run = num_train_steps_done + training_progress.num_seen_tokens_current_run = ( + self.global_num_tokens_per_train_step * num_train_steps_done ) - losses = { - "train loss avg": ResultItem(train_loss_avg, decimal_places=2), - "train loss last": ResultItem(train_loss_last_batch, decimal_places=2), - } - - consumed_tokens = torch.tensor(training_progress.num_seen_tokens_total) - metrics = { - "consumed tokens": ResultItem(consumed_tokens, 0), - "grad norm avg": ResultItem(torch.mean(torch.Tensor(gradient_norm_scores)), 2), - "grad norm last": ResultItem(torch.tensor(gradient_norm_scores[-1]), 2), - } - gradient_norm_scores = [] - mfu_score = torch.tensor(-1.0) - if self.mfu_calculator is not None: - mfu_score = self.mfu_calculator.compute(num_samples_per_second=global_num_samples_per_second) - - # Collect peak memory depending on device type. On CPU we fall back to RSS (if available) or -1. - if device.type == "cuda": - peak_memory_MB = torch.cuda.max_memory_allocated(device) / 1024**2 # in MB - torch.cuda.reset_peak_memory_stats(device) - else: - # ru_maxrss is in kilobytes on Linux; convert to MB. Use -1.0 if resource unavailable. - try: - import resource # Standard lib (POSIX). Not available on some platforms. 
- - peak_memory_MB = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 - except Exception: - peak_memory_MB = -1.0 - - training_metrics = EvaluationResultBatch( - losses=losses, - metrics=metrics, - # TODO: hardcoded metric key - throughput_metrics={ - "train samples/s": ResultItem(global_num_samples_per_second, 1), - "train mfu (16-bit)": ResultItem(mfu_score, 2), - "lr mean": ResultItem(torch.tensor(lr_scheduler.get_last_lr()).mean()), - "peak memory rank 0 (MB)": ResultItem(torch.tensor(peak_memory_MB), 2), - }, - dataloader_tag=train_loader.dataloader_tag, + + # The batch_loss might be None if we use pipeline parallelism and are not the last stage. + if batch_loss is not None: + # Save the batch loss + cumulated_losses[0] += batch_loss.item() + # This works, because we always drop the last batch in case it has less samples than the batch size + cumulated_losses[-1] += 1 # number of local batches + + # gradient norm is already synced across all ranks + if gradient_norm_score is not None: + gradient_norm_scores.append(gradient_norm_score.item()) + + local_num_seen_samples += torch.tensor(len(batch)).to(device) + + self._publish_progress( + progress_publisher=self.progress_publisher, num_train_steps_done=training_progress.num_seen_steps_total, + dataloader_tag=train_loader.dataloader_tag, ) - print_rank_0(f"{datetime.now().isoformat(timespec='seconds')} | {training_metrics}") - self._publish_evaluation_result( - evaluation_result_publisher=self.evaluation_result_publisher, - evaluation_result=training_metrics, - ) - - cumulated_losses = self._reset_tracked_losses() - if step_performed: - evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) - checkpointing_callback(training_progress=training_progress) + # Check if model performance should be logged + if training_progress.num_seen_steps_total % training_log_interval_in_steps == 0 and step_performed: + forward_backward_time_recorder.stop() + forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device) + forward_backward_time_recorder.reset() + forward_backward_time_recorder.start() + + global_num_seen_samples = local_num_seen_samples * self.dp_degree + local_num_seen_samples = 0 + global_num_samples_per_second = global_num_seen_samples / forward_backward_time + + # TODO: insert reducer from outside so Trainer is independent of FSDP + # add the loss and gradient norm for the LAST batch + cumulated_losses[1] = batch_loss.item() if batch_loss is not None else 0.0 + + reduced_losses = Reducer.reduce( + tensor=cumulated_losses, + operation=dist.ReduceOp.SUM, + # 1.) summed batch loss / (num batches * (world size / dp_degree)) + # 2.) 
last batch loss / (world size / pp_degree) + post_processing_fun=lambda t: torch.stack( + [t[0] / t[-1], t[1] / dist.get_world_size() * self.pp_degree] + ), + ) + + train_loss_avg, train_loss_last_batch = ( + reduced_losses[0], + reduced_losses[1], + ) + losses = { + "train loss avg": ResultItem(train_loss_avg, decimal_places=2), + "train loss last": ResultItem(train_loss_last_batch, decimal_places=2), + } + + consumed_tokens = torch.tensor(training_progress.num_seen_tokens_total) + metrics = { + "consumed tokens": ResultItem(consumed_tokens, 0), + "grad norm avg": ResultItem(torch.mean(torch.Tensor(gradient_norm_scores)), 2), + "grad norm last": ResultItem(torch.tensor(gradient_norm_scores[-1]), 2), + } + gradient_norm_scores = [] + mfu_score = torch.tensor(-1.0) + if self.mfu_calculator is not None: + mfu_score = self.mfu_calculator.compute(num_samples_per_second=global_num_samples_per_second) + + # Collect peak memory depending on device type. On CPU we fall back to RSS (if available) or -1. + if device.type == "cuda": + peak_memory_MB = torch.cuda.max_memory_allocated(device) / 1024**2 # in MB + torch.cuda.reset_peak_memory_stats(device) + else: + # ru_maxrss is in kilobytes on Linux; convert to MB. Use -1.0 if resource unavailable. + try: + import resource # Standard lib (POSIX). Not available on some platforms. + + peak_memory_MB = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 + except Exception: + peak_memory_MB = -1.0 + + training_metrics = EvaluationResultBatch( + losses=losses, + metrics=metrics, + # TODO: hardcoded metric key + throughput_metrics={ + "train samples/s": ResultItem(global_num_samples_per_second, 1), + "train mfu (16-bit)": ResultItem(mfu_score, 2), + "lr mean": ResultItem(torch.tensor(lr_scheduler.get_last_lr()).mean()), + "peak memory rank 0 (MB)": ResultItem(torch.tensor(peak_memory_MB), 2), + }, + dataloader_tag=train_loader.dataloader_tag, + num_train_steps_done=training_progress.num_seen_steps_total, + ) + print_rank_0(f"{datetime.now().isoformat(timespec='seconds')} | {training_metrics}") + self._publish_evaluation_result( + evaluation_result_publisher=self.evaluation_result_publisher, + evaluation_result=training_metrics, + ) + + cumulated_losses = self._reset_tracked_losses() + if step_performed: + evaluation_callback(num_train_steps_done=training_progress.num_seen_steps_total) + checkpointing_callback(training_progress=training_progress) + profiler_cm.step() def _reset_tracked_losses(self): # Initializes and returns a tensor representing the cumulated loss and gradient norm. 
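The training loop above only relies on the small contract defined by SteppableProfilerIF: enter the profiler once around the batch loop and call step() exactly once per micro-batch. A minimal sketch of that contract, using the no-op profiler and a dummy loop (the range is just a stand-in for iterating the train dataloader):

from modalities.utils.profilers.profilers import SteppableNoProfiler

profiler = SteppableNoProfiler(num_steps=-1)  # no-op profiler; num_steps only feeds __len__

with profiler as profiler_cm:
    for _micro_batch_id in range(8):  # stand-in for the train dataloader
        # forward / backward / optimizer step would run here
        profiler_cm.step()  # advance the profiler schedule once per micro-batch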
diff --git a/src/modalities/utils/profilers/profiler_configs.py b/src/modalities/utils/profilers/profiler_configs.py index 2a056a81c..dc58f3982 100644 --- a/src/modalities/utils/profilers/profiler_configs.py +++ b/src/modalities/utils/profilers/profiler_configs.py @@ -19,7 +19,6 @@ class SteppableKernelProfilerConfig(BaseModel): num_active_steps: int profiler_activities: list[ModalitiesProfilerActivity] record_shapes: bool - profile_memory: bool with_flops: bool with_stack: bool with_modules: bool diff --git a/src/modalities/utils/profilers/profiler_factory.py b/src/modalities/utils/profilers/profiler_factory.py index 0dfe335fc..113e0e9a2 100644 --- a/src/modalities/utils/profilers/profiler_factory.py +++ b/src/modalities/utils/profilers/profiler_factory.py @@ -21,7 +21,6 @@ def create_steppable_kernel_profiler( num_active_steps: int, profiler_activities: list[ModalitiesProfilerActivity], record_shapes: bool, - profile_memory: bool, with_flops: bool, with_stack: bool, with_modules: bool, @@ -50,7 +49,6 @@ def create_steppable_kernel_profiler( num_active_steps=num_active_steps, profiler_activities=profiler_activities_converted, record_shapes=record_shapes, - profile_memory=profile_memory, with_flops=with_flops, with_stack=with_stack, with_modules=with_modules, diff --git a/src/modalities/utils/profilers/profilers.py b/src/modalities/utils/profilers/profilers.py index 40ffe191e..5b6828223 100644 --- a/src/modalities/utils/profilers/profilers.py +++ b/src/modalities/utils/profilers/profilers.py @@ -66,7 +66,7 @@ def step(self): class SteppableNoProfiler(SteppableProfilerIF): - def __init__(self, num_steps: int) -> None: + def __init__(self, num_steps: int = -1) -> None: super().__init__() self._num_steps = num_steps @@ -123,6 +123,7 @@ def step(self): self._curr_step == self._num_wait_steps + self._num_warmup_steps + self._num_active_steps and self._num_active_steps > 0 ): + self._memory_snapshot_path.parent.mkdir(parents=True, exist_ok=True) with open(self._memory_snapshot_path, "wb") as output: pickle.dump(torch.cuda.memory._snapshot(), output) @@ -135,26 +136,27 @@ def __init__( num_active_steps: int, profiler_activities: list[torch.profiler.ProfilerActivity], record_shapes: bool, - profile_memory: bool, with_flops: bool, with_stack: bool, with_modules: bool, trace_output_path: Path, summary_output_path: Path, - ) -> None: # TODO specify Callable type + ) -> None: super().__init__() self._num_wait_steps = num_wait_steps self._num_warmup_steps = num_warmup_steps self._num_active_steps = num_active_steps self._profiler_activities = profiler_activities self._record_shapes = record_shapes - self._profile_memory = profile_memory self._with_flops = with_flops self._with_stack = with_stack self._with_modules = with_modules self._trace_output_path = trace_output_path self._summary_output_path = summary_output_path self._kernel_profiler = None + self._exported_already = False + + self._kernel_profiler_cm = None def __enter__(self): if self._kernel_profiler is not None: @@ -164,23 +166,24 @@ def __enter__(self): activities=self._profiler_activities, schedule=schedule(wait=self._num_wait_steps, warmup=self._num_warmup_steps, active=self._num_active_steps), record_shapes=self._record_shapes, - profile_memory=self._profile_memory, with_flops=self._with_flops, with_stack=self._with_stack, with_modules=self._with_modules, ) - + self._exported_already = False + self._kernel_profiler_cm = self._kernel_profiler.__enter__() return self def __exit__(self, exc_type, exc_value, traceback): - 
self._export_profiling_results() + if not self._exported_already: + self._export_profiling_results() + self._kernel_profiler_cm.__exit__(exc_type, exc_value, traceback) self._kernel_profiler = None if self._curr_step is None: raise RuntimeError("SteppableKernelProfiler exited without being entered") if self._curr_step < self._num_wait_steps + self._num_warmup_steps + self._num_active_steps: # if we exit before finishing all steps, dump the memory snapshot raise RuntimeError("SteppableKernelProfiler exited before finishing all steps") - return def __len__(self): return self._num_wait_steps + self._num_warmup_steps + self._num_active_steps @@ -188,10 +191,16 @@ def __len__(self): def step(self): if self._curr_step is None: raise RuntimeError("SteppableKernelProfiler.step() called outside of context manager") - if self._kernel_profiler is None: + elif self._kernel_profiler is None: raise RuntimeError("SteppableKernelProfiler.step() called when profiler is not initialized") + self._curr_step += 1 - self._kernel_profiler.step() + if self._curr_step <= self._num_wait_steps + self._num_warmup_steps + self._num_active_steps: + self._kernel_profiler.step() + elif not self._exported_already: + self._export_profiling_results() + self._kernel_profiler_cm.__exit__(None, None, None) + self._exported_already = True def _export_profiling_results(self) -> None: # Export profiling results to specified output paths if the current rank is in profiled_ranks. @@ -200,6 +209,7 @@ def _export_profiling_results(self) -> None: "SteppableKernelProfiler._export_profiling_results() called when profiler is not initialized" ) logger.info("Saving profiling results...") + self._trace_output_path.parent.mkdir(parents=True, exist_ok=True) self._kernel_profiler.export_chrome_trace(self._trace_output_path.as_posix()) table = self._kernel_profiler.key_averages().table() self._summary_output_path.parent.mkdir(parents=True, exist_ok=True) From 68dd9d2b63ea7226dbce795fdfedc4415e4ae157 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 22 Dec 2025 11:48:58 +0100 Subject: [PATCH 12/14] feat: added memory profiling to kernel profiler --- src/modalities/utils/profilers/profiler_configs.py | 1 + src/modalities/utils/profilers/profiler_factory.py | 2 ++ src/modalities/utils/profilers/profilers.py | 3 +++ 3 files changed, 6 insertions(+) diff --git a/src/modalities/utils/profilers/profiler_configs.py b/src/modalities/utils/profilers/profiler_configs.py index dc58f3982..71fb029cf 100644 --- a/src/modalities/utils/profilers/profiler_configs.py +++ b/src/modalities/utils/profilers/profiler_configs.py @@ -18,6 +18,7 @@ class SteppableKernelProfilerConfig(BaseModel): num_warmup_steps: int num_active_steps: int profiler_activities: list[ModalitiesProfilerActivity] + profile_memory: bool record_shapes: bool with_flops: bool with_stack: bool diff --git a/src/modalities/utils/profilers/profiler_factory.py b/src/modalities/utils/profilers/profiler_factory.py index 113e0e9a2..538cb5f9a 100644 --- a/src/modalities/utils/profilers/profiler_factory.py +++ b/src/modalities/utils/profilers/profiler_factory.py @@ -20,6 +20,7 @@ def create_steppable_kernel_profiler( num_warmup_steps: int, num_active_steps: int, profiler_activities: list[ModalitiesProfilerActivity], + profile_memory: bool, record_shapes: bool, with_flops: bool, with_stack: bool, @@ -48,6 +49,7 @@ def create_steppable_kernel_profiler( num_warmup_steps=num_warmup_steps, num_active_steps=num_active_steps, 
profiler_activities=profiler_activities_converted, + profile_memory=profile_memory, record_shapes=record_shapes, with_flops=with_flops, with_stack=with_stack, diff --git a/src/modalities/utils/profilers/profilers.py b/src/modalities/utils/profilers/profilers.py index 5b6828223..961646409 100644 --- a/src/modalities/utils/profilers/profilers.py +++ b/src/modalities/utils/profilers/profilers.py @@ -135,6 +135,7 @@ def __init__( num_warmup_steps: int, num_active_steps: int, profiler_activities: list[torch.profiler.ProfilerActivity], + profile_memory: bool, record_shapes: bool, with_flops: bool, with_stack: bool, @@ -147,6 +148,7 @@ def __init__( self._num_warmup_steps = num_warmup_steps self._num_active_steps = num_active_steps self._profiler_activities = profiler_activities + self._profile_memory = profile_memory self._record_shapes = record_shapes self._with_flops = with_flops self._with_stack = with_stack @@ -164,6 +166,7 @@ def __enter__(self): self._curr_step = 0 self._kernel_profiler = profile( activities=self._profiler_activities, + profile_memory=self._profile_memory, schedule=schedule(wait=self._num_wait_steps, warmup=self._num_warmup_steps, active=self._num_active_steps), record_shapes=self._record_shapes, with_flops=self._with_flops, From e4fe4b0f7bd5d2a00a24d5a191ca7c7a079e6700 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:36:32 +0100 Subject: [PATCH 13/14] refactor: added experiments_root_path to warmstart API and improved error handling --- src/modalities/__main__.py | 85 ++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 9bf60324f..bb29ce2fe 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -95,16 +95,6 @@ def CMD_entry_point_run_modalities( test_comm (bool): If set, run a communication test before training. """ - def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str: - # Format an exception into a structured JSON string with error message, type, and stack trace. - error = { - "error": str(e), - "type": type(e).__name__, - "stacktrace": traceback.format_exception(type(e), e, e.__traceback__), - } - - return json.dumps({"environment": environment, "error": error}, indent=2) - try: with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): if test_comm: @@ -116,24 +106,16 @@ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str: components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel) main_obj.run(components) except Exception as e: - if error_log_folder is not None: - environment = { - "rank": int(os.environ["RANK"] if "RANK" in os.environ else -1), - "local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1), - "world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1), - "hostname": socket.gethostname(), - } - error_log_folder = ( - error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log" - ) - error_log_folder.parent.mkdir(parents=True, exist_ok=True) - with open(error_log_folder, "w", encoding="utf-8") as f: - f.write(_format_exception_as_json(e, environment)) - - raise RuntimeError(f"An error occurred while running the training: {e}. 
") from e + _exception_handling(e, error_log_folder) @main.command(name="warmstart") +@click.option( + "--experiments_root_path", + type=click_pathlib.Path(exists=True), + required=True, + help="Path to the root directory where experiment folders will be created.", +) @click.option( "--config_file_path", type=click_pathlib.Path(exists=True), @@ -146,10 +128,22 @@ def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str: required=True, help="Path to the file containing the model and optimizer checkpoint paths from the last successful checkpoint.", ) -def CMD_entry_point_warmstart_modalities(config_file_path: Path, last_checkpoint_info_file_path: Path): +@click.option( + "--error_log_folder", + type=click_pathlib.Path(), + default=None, + help="Optional path to a folder where error logs will be written.", +) +def CMD_entry_point_warmstart_modalities( + experiments_root_path: Path, + config_file_path: Path, + last_checkpoint_info_file_path: Path, + error_log_folder: Optional[Path] = None, +): """Entrypoint to run the model warmstart. Args: + experiments_root_path (Path): Path to the root directory where experiment folders will be created. config_file_path (Path): Path to the YAML warmstart config file. last_checkpoint_info_file_path (Path): Path to the file containing the model and optimizer checkpoint paths from the last successful checkpoint. @@ -167,10 +161,15 @@ def get_last_checkpoint_resolver_fun(var_name: str, last_checkpoint_info_file_pa get_last_checkpoint_resolver_fun, last_checkpoint_info_file_path=last_checkpoint_info_file_path ) } - with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): - main_obj = Main(config_file_path, additional_resolver_funs=resolver_funs) - components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel) - main_obj.run(components) + try: + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + main_obj = Main( + config_file_path, experiments_root_path=experiments_root_path, additional_resolver_funs=resolver_funs + ) + components = main_obj.build_components(components_model_type=TrainingComponentsInstantiationModel) + main_obj.run(components) + except Exception as e: + _exception_handling(e, error_log_folder) @main.command(name="generate_text") @@ -724,5 +723,31 @@ def CMD_entry_point_run_train_step_profiler( ) +def _format_exception_as_json(e: Exception, environment: dict[str, Any]) -> str: + # Format an exception into a structured JSON string with error message, type, and stack trace. + error = { + "error": str(e), + "type": type(e).__name__, + "stacktrace": traceback.format_exception(type(e), e, e.__traceback__), + } + return json.dumps({"environment": environment, "error": error}, indent=2) + + +def _exception_handling(e: Exception, error_log_folder: Path | None): + if error_log_folder is not None: + environment = { + "rank": int(os.environ["RANK"] if "RANK" in os.environ else -1), + "local_rank": int(os.environ["LOCAL_RANK"] if "LOCAL_RANK" in os.environ else -1), + "world_size": int(os.environ["WORLD_SIZE"] if "WORLD_SIZE" in os.environ else -1), + "hostname": socket.gethostname(), + } + error_log_folder = error_log_folder / f"error_logs_{environment['hostname']}_{environment['local_rank']}.log" + error_log_folder.parent.mkdir(parents=True, exist_ok=True) + with open(error_log_folder, "w", encoding="utf-8") as f: + f.write(_format_exception_as_json(e, environment)) + + raise RuntimeError(f"An error occurred while running the training: {e}. 
") from e + + if __name__ == "__main__": main() From fbab937b2f62791e873866c9d22613358becd327 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 29 Dec 2025 12:39:26 +0100 Subject: [PATCH 14/14] refactor: refactored wamstart tutorial scripts --- src/modalities/config/config.py | 9 ++++++--- .../warmstart/configs/pre_training_config.yaml | 5 +++-- .../warmstart/configs/warmstart_config.yaml | 5 +++-- .../scripts/check_checkpoint_consistency.py | 18 ++++++++++++------ .../scripts/pre_train_and_warmstart.sh | 10 +++++----- 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index b2e9ff50d..311c6bbbd 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -505,14 +505,16 @@ class ParallelDegreeConfig(BaseModel): def load_app_config_dict( config_file_path: Path, - experiments_root_path: Path, - experiment_id: Optional[str] = None, + experiments_root_path: Path | None = None, + experiment_id: str | None = None, additional_resolver_funs: Optional[dict[str, Resolver]] = None, ) -> dict[str, YAMLValue]: """Load the application configuration from the given YAML file. Args: config_file_path (Path): YAML config file. + experiments_root_path: (Path, optional): The path to the experiments root directory. + Defaults to None. experiment_id (str, optional): The experiment_id of the current run. additional_resolver_funs (dict[str, Resolver], optional): Additional resolver functions. @@ -538,8 +540,9 @@ def node_env_resolver_fun(var_name: str) -> int | None: modalities_env_kwargs: dict[str, Any] = { "config_file_path": config_file_path, "config_folder_path": config_file_path.parent, - "experiments_root_path": experiments_root_path, } + if experiments_root_path is not None: + modalities_env_kwargs["experiments_root_path"] = experiments_root_path if experiment_id is not None: modalities_env_kwargs["experiment_id"] = experiment_id OmegaConf.register_new_resolver( diff --git a/tutorials/warmstart/configs/pre_training_config.yaml b/tutorials/warmstart/configs/pre_training_config.yaml index 56ba18c8f..bff64c7ac 100644 --- a/tutorials/warmstart/configs/pre_training_config.yaml +++ b/tutorials/warmstart/configs/pre_training_config.yaml @@ -10,7 +10,8 @@ settings: global_rank: ${cuda_env:RANK} world_size: ${cuda_env:WORLD_SIZE} paths: - checkpoint_saving_path: ../data/checkpoints + experiments_root_path: ${modalities_env:experiments_root_path} + experiment_folder_path: ${settings.paths.experiments_root_path}/${settings.experiment_id} train_dataset_path: ../data/mem_map/redpajama_v2_samples_512_train.pbin intervals: training_log_interval_in_steps: 1 @@ -100,7 +101,7 @@ checkpoint_saving: component_key: checkpoint_saving_execution variant_key: dcp config: - checkpoint_path: ${settings.paths.checkpoint_saving_path} + checkpoint_path: ${settings.paths.experiment_folder_path}/checkpoints global_rank: ${settings.cuda_env.global_rank} experiment_id: ${settings.experiment_id} diff --git a/tutorials/warmstart/configs/warmstart_config.yaml b/tutorials/warmstart/configs/warmstart_config.yaml index a12286fef..b3e467741 100644 --- a/tutorials/warmstart/configs/warmstart_config.yaml +++ b/tutorials/warmstart/configs/warmstart_config.yaml @@ -10,7 +10,8 @@ settings: global_rank: ${cuda_env:RANK} world_size: ${cuda_env:WORLD_SIZE} paths: - checkpoint_saving_path: ../data/checkpoints + experiments_root_path: ${modalities_env:experiments_root_path} + experiment_folder_path: 
${settings.paths.experiments_root_path}/${settings.experiment_id} train_dataset_path: ../data/mem_map/redpajama_v2_samples_512_train.pbin intervals: training_log_interval_in_steps: 1 @@ -124,7 +125,7 @@ checkpoint_saving: component_key: checkpoint_saving_execution variant_key: dcp config: - checkpoint_path: ${settings.paths.checkpoint_saving_path} + checkpoint_path: ${settings.paths.experiment_folder_path}/checkpoints global_rank: ${settings.cuda_env.global_rank} experiment_id: ${settings.experiment_id} diff --git a/tutorials/warmstart/scripts/check_checkpoint_consistency.py b/tutorials/warmstart/scripts/check_checkpoint_consistency.py index c7f216823..84c62e6d7 100644 --- a/tutorials/warmstart/scripts/check_checkpoint_consistency.py +++ b/tutorials/warmstart/scripts/check_checkpoint_consistency.py @@ -10,19 +10,25 @@ def _get_checkpoint_file_name_without_eid(checkpoint_file_name: str) -> str: def test_checkpoint_files_exist(checkpoint_folder_path: list[Path], expected_checkpoint_names: list[str]): - # Check if all the checkpoint files exist and have the correct names - checkpoint_paths = glob.glob(str(checkpoint_folder_path / "**/*"), recursive=True) + for expected_checkpoint_name in expected_checkpoint_names: + # Check if all the checkpoint files exist and have the correct names + checkpoint_paths = glob.glob( + str(checkpoint_folder_path / f"**/checkpoints/**/*{expected_checkpoint_name}/*"), + recursive=True, + include_hidden=True, + ) + checkpoint_files = [p for p in checkpoint_paths if os.path.isfile(p)] - assert len(checkpoint_paths) == 17, "ERROR! Expected 6 checkpoint files." - - assert len([p for p in checkpoint_paths if p.endswith(".distcp")]), "ERROR! Expected 6 checkpoint files." + assert len(checkpoint_files) == 3, f"ERROR! Expected 3 checkpoint files. Got {len(checkpoint_files)}." + num_checkpoint_files = len([p for p in checkpoint_files if p.endswith(".distcp")]) + assert num_checkpoint_files == 2, f"ERROR! Expected 2 checkpoint files. Got {num_checkpoint_files}." 
if __name__ == "__main__": current_file_path = Path(__file__).resolve() os.chdir(current_file_path.parent) - checkpoint_folder_path = Path("../data/checkpoints") + checkpoint_folder_path = Path("../experiments") expected_checkpoint_folder_names = [ # pretrain checkpoint diff --git a/tutorials/warmstart/scripts/pre_train_and_warmstart.sh b/tutorials/warmstart/scripts/pre_train_and_warmstart.sh index 0beba7089..3f88ae275 100644 --- a/tutorials/warmstart/scripts/pre_train_and_warmstart.sh +++ b/tutorials/warmstart/scripts/pre_train_and_warmstart.sh @@ -31,23 +31,23 @@ echo "> run warmstart example on CUDA_VISIBLE_DEVICES="$CUDA_VISIBLE_DEVICES cd "$(dirname "$0")" rm -rf ../data/ +rm -rf ../experiments # run preprocessing modalities data create_raw_index --index_path ../data/mem_map/redpajama_v2_samples_512_train.idx ../../getting_started/data/raw/redpajama_v2_samples_512_train.jsonl modalities data pack_encoded_data ../configs/tokenization_config_train.yaml +mkdir -p ../experiments # run pretraining -CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 2 $(which modalities) run --config_file_path ../configs/pre_training_config.yaml +CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES torchrun --rdzv-endpoint localhost:29456 --nnodes 1 --nproc_per_node 2 $(which modalities) run --config_file_path ../configs/pre_training_config.yaml --experiments_root_path ../experiments # run warmstart -checkpoint_path=$(find ../data/checkpoints -name "last_checkpoint_info.json" -exec realpath {} \;) -CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 2 $(which modalities) warmstart --config_file_path ../configs/warmstart_config.yaml --last_checkpoint_info_file_path $checkpoint_path +checkpoint_path=$(find ../experiments -name "last_checkpoint_info.json" -exec realpath {} \;) +CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES torchrun --rdzv-endpoint localhost:29457 --nnodes 1 --nproc_per_node 2 $(which modalities) warmstart --config_file_path ../configs/warmstart_config.yaml --experiments_root_path ../experiments --last_checkpoint_info_file_path $checkpoint_path # add some consistency checks python check_checkpoint_consistency.py -rm -rf ../data/ - echo "Finished warmstart example"
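For reference, a minimal sketch of how the steppable profilers introduced in this series compose when wired up by hand rather than through the component factory; the snapshot path and step counts below are placeholders, and the memory profiler needs a CUDA device:

from pathlib import Path

from modalities.utils.profilers.profilers import (
    SteppableCombinedProfiler,
    SteppableMemoryProfiler,
    SteppableNoProfiler,
)

memory_profiler = SteppableMemoryProfiler(
    memory_snapshot_path=Path("../experiments/memory_snapshot.pickle"),  # placeholder path
    num_wait_steps=5,
    num_warmup_steps=5,
    num_active_steps=3,
)

# The combined profiler enters all children and steps them in lock-step;
# SteppableNoProfiler instances are ignored when the minimum step length is computed.
profiler = SteppableCombinedProfiler(profilers=[memory_profiler, SteppableNoProfiler(num_steps=-1)])

with profiler as profiler_cm:
    for _ in range(len(profiler)):  # wait + warmup + active steps
        # one training step would run here
        profiler_cm.step()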