From 5f9f50e228584e6b355998ac03c4d33873a8b0ca Mon Sep 17 00:00:00 2001 From: rrutmann Date: Mon, 8 Dec 2025 14:18:31 +0000 Subject: [PATCH 01/10] fix: Initialize different weights across TP ranks --- src/modalities/models/gpt2/gpt2_model.py | 11 +++++++++++ src/modalities/models/model_factory.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 0a846b38a..66e19376f 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -7,8 +7,10 @@ import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator, validator +from torch.distributed.device_mesh import DeviceMesh from modalities.config.lookup_enum import LookupEnum +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType from modalities.config.utils import convert_base_model_config_to_dict from modalities.models.components.layer_norms import ( LayerNormConfig, @@ -17,6 +19,7 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method from modalities.util import parse_enum_by_name try: @@ -367,6 +370,7 @@ class GPT2LLMConfig(BaseModel): use_weight_tying: bool seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 + device_mesh: Optional[PydanticDeviceMeshIFType] = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -834,6 +838,7 @@ def __init__( use_weight_tying: bool, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, + device_mesh: DeviceMesh | None = None, ): """ Initializes the GPT2LLM object. @@ -862,12 +867,18 @@ def __init__( enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. + device_mesh (DeviceMesh | None): The device mesh for parallelism. Defaults to None. 
""" weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } + # Set different random seed for each TP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP + ): + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 3acb17f95..c4c953eaf 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -578,6 +578,7 @@ def get_gpt2_model( use_meta_device: Optional[bool] = False, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, + device_mesh: DeviceMesh | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -601,6 +602,7 @@ def get_gpt2_model( seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, + device_mesh=device_mesh, ) if use_meta_device and use_weight_tying: raise ValueError( From 8c8c5abb716e86bfba86429462152944a890864a Mon Sep 17 00:00:00 2001 From: rrutmann Date: Tue, 9 Dec 2025 09:36:00 +0000 Subject: [PATCH 02/10] feat: Consider pp rank for model seed --- src/modalities/models/gpt2/gpt2_model.py | 38 ++++++++++++++++++++---- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 66e19376f..dd5dbd3ec 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -19,7 +19,12 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method +from modalities.running_env.fsdp.device_mesh import ( + ParallelismDegrees, + get_parallel_degree, + get_parallel_rank, + has_parallelism_method, +) from modalities.util import parse_enum_by_name try: @@ -874,11 +879,9 @@ def __init__( "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - # Set different random seed for each TP rank to ensure diversity - if seed is not None and has_parallelism_method( - device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP - ): - seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) + # Set different random seed for each TP and PP rank to ensure diversity + if seed is not None and device_mesh is not None: + seed = _offset_seed_by_parallel_ranks(seed=seed, device_mesh=device_mesh) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key @@ -1069,3 +1072,26 @@ def manual_scaled_dot_product_attention( attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value + + +def _offset_seed_by_parallel_ranks(seed: int, device_mesh: DeviceMesh) -> int: + """ + Return a seed shifted by the TP/PP ranks so each TP/PP pair produces a distinct value. 
+ """ + tp_rank = None + pp_rank = None + pp_degree = 1 + + if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP): + tp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) + if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP): + pp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) + pp_degree = get_parallel_degree(device_mesh=device_mesh, parallelism_methods=[ParallelismDegrees.PP]) + + if tp_rank is not None and pp_rank is not None: + return seed + tp_rank * pp_degree + pp_rank + if tp_rank is not None: + return seed + tp_rank + if pp_rank is not None: + return seed + pp_rank + return seed From ab3daa01a02adff9e20d2aed1b56b28845b1ec1c Mon Sep 17 00:00:00 2001 From: rrutmann Date: Wed, 10 Dec 2025 09:43:15 +0000 Subject: [PATCH 03/10] fix: Only consider PP rank for seeding --- src/modalities/models/gpt2/gpt2_model.py | 38 ++++-------------------- 1 file changed, 6 insertions(+), 32 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index dd5dbd3ec..c8f82ecf6 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -19,12 +19,7 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU -from modalities.running_env.fsdp.device_mesh import ( - ParallelismDegrees, - get_parallel_degree, - get_parallel_rank, - has_parallelism_method, -) +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method from modalities.util import parse_enum_by_name try: @@ -879,9 +874,11 @@ def __init__( "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - # Set different random seed for each TP and PP rank to ensure diversity - if seed is not None and device_mesh is not None: - seed = _offset_seed_by_parallel_ranks(seed=seed, device_mesh=device_mesh) + # Set different random seed for each PP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key @@ -1072,26 +1069,3 @@ def manual_scaled_dot_product_attention( attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value - - -def _offset_seed_by_parallel_ranks(seed: int, device_mesh: DeviceMesh) -> int: - """ - Return a seed shifted by the TP/PP ranks so each TP/PP pair produces a distinct value. 
- """ - tp_rank = None - pp_rank = None - pp_degree = 1 - - if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP): - tp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP) - if has_parallelism_method(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP): - pp_rank = get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) - pp_degree = get_parallel_degree(device_mesh=device_mesh, parallelism_methods=[ParallelismDegrees.PP]) - - if tp_rank is not None and pp_rank is not None: - return seed + tp_rank * pp_degree + pp_rank - if tp_rank is not None: - return seed + tp_rank - if pp_rank is not None: - return seed + pp_rank - return seed From 62a1743dcf6561f83c6ff5f545aa79a23bbfd2b0 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 12 Dec 2025 17:21:56 +0000 Subject: [PATCH 04/10] test: Add test for different parameters on tp/pp ranks --- .../test_parallel_seed_initialization.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 tests/fsdp2_parallelization/test_parallel_seed_initialization.py diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py new file mode 100644 index 000000000..58f0d10c5 --- /dev/null +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -0,0 +1,169 @@ +import logging +import multiprocessing as py_mp +import os +import re +import traceback +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType +from modalities.logging_broker.messages import Message +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_parallel_rank +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv +from tests.utility import monitor_child_processes + +working_dir = Path(os.path.dirname(__file__)) +tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp" +working_dir = working_dir / "configs" + + +@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This e2e test requires 8 GPUs.", +) +class TestParallelSeedInitialization: + WORLD_SIZE = 8 + RDVZ_PORT = 24574 + + def test_parameters_follow_parallelism(self, tmp_path: Path): + manager = py_mp.Manager() + error_queue = manager.Queue() + proc_ctx = mp.spawn( + self._seed_distribution_impl_wrapper, + args=(self.WORLD_SIZE, tmp_path, error_queue), + nprocs=self.WORLD_SIZE, + join=False, + ) + monitor_child_processes(manager, error_queue, proc_ctx) + + def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_path: Path, error_queue: Any): + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=TestParallelSeedInitialization.RDVZ_PORT, + ): + try: + self._seed_distribution_impl(world_size=world_size, tmp_path=tmp_path) + except Exception as exc: + tb = traceback.format_exc() + logging.error(f"Process {process_id} (seed distribution test) encountered an error:\n{exc}") + logging.error(tb) + try: + 
error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to put exception info into error queue (seed distribution test).") + os._exit(1) + + def _seed_distribution_impl(self, world_size: int, tmp_path: Path): + device_mesh = get_device_mesh( + device_type="cuda", + data_parallel_replicate_degree=2, + data_parallel_shard_degree=1, + tensor_parallel_degree=2, + pipeline_parallel_degree=2, + context_parallel_degree=1, + enable_loss_parallel=False, + world_size=world_size, + ) + + # initialize components + class ComponentsInstantiationModel(BaseModel): + fsdp_model: PydanticFSDP2ModuleType + device_mesh: PydanticDeviceMeshIFType + + config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path) + main_obj = Main(config_file_path) + components = main_obj.build_components(components_model_type=ComponentsInstantiationModel) + model = components.fsdp_model + device_mesh = components.device_mesh + # get first transformer block's MLP weight parameter shards + block_key = next(iter(model.transformer.h.keys())) + block = model.transformer.h[block_key] + payload = { + "tensor_shard": block.mlp.W.weight.to_local().cpu(), + "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP), + "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP), + "dp_shard_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.DP_SHARD), + "block_key": block_key, + } + + gather_list: list[dict[str, Any]] | None = [None] * world_size if dist.get_rank() == 0 else None + dist.gather_object(payload, gather_list, dst=0) + + if dist.get_rank() == 0: + assert gather_list is not None + TestParallelSeedInitialization._assert_parameter_distribution(gather_list) + dist.barrier() + + @staticmethod + def _assert_parameter_distribution(records: list[dict[str, Any]]): + combos: dict[tuple[int, int], list[dict[str, Any]]] = {} + for record in records: + key = (record["pp_rank"], record["tp_rank"]) + combos.setdefault(key, []).append(record) + + expected_combo_count = 4 + assert ( + len(combos) == expected_combo_count + ), f"Expected {expected_combo_count} PP/TP combinations, got {len(combos)}" + + combo_tensors: dict[tuple[int, int], torch.Tensor] = {} + for (pp_rank, tp_rank), entries in combos.items(): + shards = sorted(entries, key=lambda e: e["dp_shard_rank"]) + combo_tensors[(pp_rank, tp_rank)] = torch.cat( + [e["tensor_shard"] for e in shards], + dim=0, + ) + + combo_items = list(combo_tensors.items()) + for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items): + for other_key, other_tensor in combo_items[idx + 1 :]: + tensors_equal = torch.equal(base_tensor, other_tensor) + assert not tensors_equal, ( + "Distinct TP/PP combinations should initialize with different weights; " + f"found match between (PP={pp_rank}, TP={tp_rank}) and (PP={other_key[0]}, TP={other_key[1]})" + ) + + def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degree: int, tmp_path: Path) -> Path: + temp_file_path = tmp_path / "pp_tp_sharding_config.yaml" + working_dir = Path(os.path.dirname(__file__)) + config_file_path = ( + working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml" + ) + + with open(config_file_path, "r") as file: + config_string = file.read() + config_dict = yaml.safe_load(config_string) + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = dp_degree + 
config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree + config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree + + # save to temporary file + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path + + +def _get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: + return [message.payload.losses[loss_key].value.item() for message in messages] + + +def _extract_seen_steps_and_tokens(filename: str) -> tuple[int, int]: + pattern = r"seen_steps_(\d+)-seen_tokens_(\d+)" + match = re.search(pattern, filename) + if match is None: + raise ValueError(f"Filename '{filename}' does not match expected pattern '{pattern}'.") + return int(match.group(1)), int(match.group(2)) From 00a595bfcbf136f11eebc9a808bee09a304aaa51 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 12 Dec 2025 17:49:32 +0000 Subject: [PATCH 05/10] test: Check for equal parameters across data parallel processes --- .../test_parallel_seed_initialization.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index 58f0d10c5..45b79327f 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -12,13 +12,14 @@ import torch.multiprocessing as mp import yaml from pydantic import BaseModel +from torch.distributed._tensor.placement_types import Replicate from modalities.__main__ import Main from modalities.batch import EvaluationResultBatch from modalities.config.config import ProcessGroupBackendType from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType from modalities.logging_broker.messages import Message -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_parallel_rank +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank from tests.end2end_tests.custom_components import MultiProcessingCudaEnv from tests.utility import monitor_child_processes @@ -67,17 +68,6 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_ os._exit(1) def _seed_distribution_impl(self, world_size: int, tmp_path: Path): - device_mesh = get_device_mesh( - device_type="cuda", - data_parallel_replicate_degree=2, - data_parallel_shard_degree=1, - tensor_parallel_degree=2, - pipeline_parallel_degree=2, - context_parallel_degree=1, - enable_loss_parallel=False, - world_size=world_size, - ) - # initialize components class ComponentsInstantiationModel(BaseModel): fsdp_model: PydanticFSDP2ModuleType @@ -88,10 +78,13 @@ class ComponentsInstantiationModel(BaseModel): components = main_obj.build_components(components_model_type=ComponentsInstantiationModel) model = components.fsdp_model device_mesh = components.device_mesh - # get first transformer block's MLP weight parameter shards + # for each pp stage get first transformer block's MLP weight parameter shards and full tensor block_key = next(iter(model.transformer.h.keys())) block = model.transformer.h[block_key] + placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape) + full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu() payload = { + "tensor_full": full_weight, "tensor_shard": block.mlp.W.weight.to_local().cpu(), "tp_rank": 
get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP), "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP), @@ -121,12 +114,24 @@ def _assert_parameter_distribution(records: list[dict[str, Any]]): combo_tensors: dict[tuple[int, int], torch.Tensor] = {} for (pp_rank, tp_rank), entries in combos.items(): + # check that full tensors are the same across data parallel processes + reference = entries[0]["tensor_full"] + seen_dp_ranks: set[int] = set() + for entry in entries: + dp_rank = entry["dp_shard_rank"] + assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}" + seen_dp_ranks.add(dp_rank) + assert torch.equal(reference, entry["tensor_full"]), ( + "Tensors within the same TP/PP combo must be identical across DP ranks; " + f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})" + ) + # concatenate all shards for this pp/tp combo shards = sorted(entries, key=lambda e: e["dp_shard_rank"]) combo_tensors[(pp_rank, tp_rank)] = torch.cat( [e["tensor_shard"] for e in shards], dim=0, ) - + # check that tensor shards differ across different pp/tp combos combo_items = list(combo_tensors.items()) for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items): for other_key, other_tensor in combo_items[idx + 1 :]: From bf06da7bf67997e20f72949cab849597ad0d7508 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 13:26:49 +0000 Subject: [PATCH 06/10] feat: Integrate seeding to model initialization --- .../composed_initialization.py | 23 ++++++++++++++-- .../initialization_routines.py | 27 ++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 190311cb6..b1b976573 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -2,6 +2,7 @@ import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field, model_validator +from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Annotated from modalities.config.pydantic_if_types import PydanticModelInitializationIFType @@ -12,6 +13,7 @@ SupportWeightInitModels, WeightInitTypes, ) +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method class ModelInitializerWrapperConfig(BaseModel): @@ -100,6 +102,8 @@ def get_composed_model_initializer( std: float | str, hidden_dim: Optional[int] = None, num_layers: int = None, + device_mesh: Optional[DeviceMesh] = None, + seed: Optional[int] = None, ) -> ModelInitializationIF: """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization. Note that plain initialization is always performed in the beginning. In case of scaled_embed, @@ -114,16 +118,28 @@ def get_composed_model_initializer( Defaults to None. num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only). Defaults to None. + device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization. + seed (Optional[int], optional): Seed for random initialization. Defaults to None. Returns: ModelInitializationIF: The Weight Initializer performing the initialization as specified. 
""" + # Set different random seed for each PP rank to ensure diversity + if seed is not None and has_parallelism_method( + device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP + ): + seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) + model_initializers = [] # plain plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN] plain_init = InitializationRoutines.get_plain_initialization( - mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes + mean=mean, + std=std, + hidden_dim=hidden_dim, + parameter_name_regexes=plain_parameter_name_regexes, + seed=seed, ) working_std = plain_init.std model_initializers.append(plain_init) @@ -136,6 +152,7 @@ def get_composed_model_initializer( std=working_std, num_layers=num_layers, parameter_name_regexes=scaled_parameter_name_regexes, + seed=seed, ) model_initializers.append(scaled_init) @@ -143,7 +160,9 @@ def get_composed_model_initializer( # scaled embed scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED] scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization( - mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes + mean=mean, + parameter_name_regexes=scaled_embed_parameter_name_regexes, + seed=seed, ) model_initializers.append(scaled_embed_init) diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index 5b4515875..36953d646 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -2,6 +2,7 @@ import re from typing import Annotated, Optional +import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator @@ -59,7 +60,11 @@ def initialize_in_place(self, model: nn.Module): class InitializationRoutines: @staticmethod def get_plain_initialization( - mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None + mean: float, + std: float | str, + parameter_name_regexes: list[str], + hidden_dim: Optional[int] = None, + seed: Optional[int] = None, ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -70,8 +75,11 @@ def get_plain_initialization( std (float): standard deviation of the normal distribution. If set to "auto", appropiate value selected as per plain initialization described in https://arxiv.org/abs/2312.16903 hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None. + parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization + should be applied + seed (Optional[int]): Random seed for initialization. Defaults to None. """ - + InitializationRoutines._set_seed(seed) # auto: choose std automatically if std == "auto": if hidden_dim is None: @@ -86,7 +94,7 @@ def get_plain_initialization( @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: list[str] + mean: float, std: float, num_layers: int, parameter_name_regexes: list[str], seed: Optional[int] = None ) -> ModelInitializationIF: """Implementation of scaled weight initialization. 
As defined in https://arxiv.org/abs/2312.16903 @@ -96,10 +104,12 @@ def get_scaled_initialization( num_layers (int): Number of layers in the model which we use to downscale std with parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied + seed (Optional[int]): Random seed for initialization. Defaults to None. Returns: WeightInitializationIF: Weight initialization object """ + InitializationRoutines._set_seed(seed) # see https://arxiv.org/abs/2312.16903 scaled_std = std / math.sqrt(2 * num_layers) @@ -109,7 +119,9 @@ def get_scaled_initialization( return initialization @staticmethod - def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF: + def get_scaled_embed_initialization( + mean: float, parameter_name_regexes: list[str], seed: Optional[int] = None + ) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). @@ -117,12 +129,19 @@ def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[st mean (float): Mean of the normal distribution parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. + seed (Optional[int]): Random seed for initialization. Defaults to None. Returns: WeightInitializationIF: Weight initialization object """ + InitializationRoutines._set_seed(seed) std = math.sqrt(0.4) initialization = NamedParameterwiseNormalInitialization( mean=mean, std=std, parameter_name_regexes=parameter_name_regexes ) return initialization + + @staticmethod + def _set_seed(seed: Optional[int]): + if seed is not None: + torch.manual_seed(seed) From b137701774c5e0bae687231865091a3c4a39b01d Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 13:37:01 +0000 Subject: [PATCH 07/10] refactor: Move seeding logic to model initialization component --- src/modalities/models/gpt2/gpt2_model.py | 11 ----------- src/modalities/models/model.py | 2 +- src/modalities/models/model_factory.py | 2 -- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index c8f82ecf6..0a846b38a 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -7,10 +7,8 @@ import torch import torch.nn as nn from pydantic import BaseModel, Field, model_validator, validator -from torch.distributed.device_mesh import DeviceMesh from modalities.config.lookup_enum import LookupEnum -from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType from modalities.config.utils import convert_base_model_config_to_dict from modalities.models.components.layer_norms import ( LayerNormConfig, @@ -19,7 +17,6 @@ RMSLayerNormConfig, ) from modalities.models.model import ActivationType, NNModel, SwiGLU -from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method from modalities.util import parse_enum_by_name try: @@ -370,7 +367,6 @@ class GPT2LLMConfig(BaseModel): use_weight_tying: bool seed: Optional[int] = None enforce_swiglu_hidden_dim_multiple_of: int = 256 - device_mesh: Optional[PydanticDeviceMeshIFType] = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -838,7 +834,6 @@ def __init__( use_weight_tying: bool, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: 
int = 256, - device_mesh: DeviceMesh | None = None, ): """ Initializes the GPT2LLM object. @@ -867,18 +862,12 @@ def __init__( enforce_swiglu_hidden_dim_multiple_of (int): Enforces the hidden dimension in the SwiGLU layer to be a multiple of this value. Note that this is only relevant if the activation_type is SwiGLU. Defaults to 256. - device_mesh (DeviceMesh | None): The device mesh for parallelism. Defaults to None. """ weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], "embedding": [".wte", ".wpe"], "layernorm": [".attention_norm", ".ffn_norm", ".lm_head_norm"], } - # Set different random seed for each PP rank to ensure diversity - if seed is not None and has_parallelism_method( - device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP - ): - seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP) super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index ac3dca96b..5dc7986b2 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -26,7 +26,7 @@ class ActivationType(str, Enum): class NNModel(nn.Module): """NNModel class to define a base model.""" - def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None): + def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None): """ Initializes an NNModel object. diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index c4c953eaf..3acb17f95 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -578,7 +578,6 @@ def get_gpt2_model( use_meta_device: Optional[bool] = False, seed: Optional[int] = None, enforce_swiglu_hidden_dim_multiple_of: int = 256, - device_mesh: DeviceMesh | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -602,7 +601,6 @@ def get_gpt2_model( seed=seed, use_weight_tying=use_weight_tying, enforce_swiglu_hidden_dim_multiple_of=enforce_swiglu_hidden_dim_multiple_of, - device_mesh=device_mesh, ) if use_meta_device and use_weight_tying: raise ValueError( From bff99f3cb7880e34df303f8764ea8dcad1aaa49b Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 14:01:10 +0000 Subject: [PATCH 08/10] chore: Add seed and device_mesh to ComposedModelInitializationConfig --- .../nn/model_initialization/composed_initialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index b1b976573..1789011f1 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -5,7 +5,7 @@ from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Annotated -from modalities.config.pydantic_if_types import PydanticModelInitializationIFType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.nn.model_initialization.initialization_routines import InitializationRoutines from modalities.nn.model_initialization.parameter_name_filters import ( @@ -32,6 +32,8 @@ class 
ComposedModelInitializationConfig(BaseModel): std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None + seed: Optional[int] = None + device_mesh: Optional[PydanticDeviceMeshIFType] = None # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces From 98ff9db1479c2da02c826bba8acf40946f677c62 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 14:09:18 +0000 Subject: [PATCH 09/10] test: Adapt test to latest changes --- .../config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml index fb8ee5f7d..8fe1d5472 100644 --- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml @@ -129,7 +129,11 @@ initialized_model: weight_init_type: scaled mean: 0.0 std: 0.02 + seed: 42 num_layers: ${model_raw.config.n_layer} + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE model_raw: component_key: model From 2e248ed2477bfa64a3e91045da222a0cc4c86f35 Mon Sep 17 00:00:00 2001 From: rrutmann Date: Fri, 19 Dec 2025 14:22:47 +0000 Subject: [PATCH 10/10] chore: Remove old code --- .../test_parallel_seed_initialization.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py index 45b79327f..b9bb2f7ca 100644 --- a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py +++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py @@ -1,7 +1,6 @@ import logging import multiprocessing as py_mp import os -import re import traceback from pathlib import Path from typing import Any @@ -15,10 +14,8 @@ from torch.distributed._tensor.placement_types import Replicate from modalities.__main__ import Main -from modalities.batch import EvaluationResultBatch from modalities.config.config import ProcessGroupBackendType from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType -from modalities.logging_broker.messages import Message from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank from tests.end2end_tests.custom_components import MultiProcessingCudaEnv from tests.utility import monitor_child_processes @@ -160,15 +157,3 @@ def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degre yaml.dump(config_dict, file) return temp_file_path - - -def _get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: - return [message.payload.losses[loss_key].value.item() for message in messages] - - -def _extract_seen_steps_and_tokens(filename: str) -> tuple[int, int]: - pattern = r"seen_steps_(\d+)-seen_tokens_(\d+)" - match = re.search(pattern, filename) - if match is None: - raise ValueError(f"Filename '{filename}' does not match expected pattern '{pattern}'.") - return int(match.group(1)), int(match.group(2))
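
A minimal, self-contained sketch of the seeding rule this series converges on: offset the base seed by the pipeline-parallel rank (as done in get_composed_model_initializer) and call torch.manual_seed before drawing the normal initialization (as in InitializationRoutines._set_seed). The helper names offset_seed_by_pp_rank and draw_init_weights, and the use of plain integer ranks in place of a DeviceMesh, are illustrative assumptions rather than APIs from this repository.

import torch


def offset_seed_by_pp_rank(seed: int, pp_rank: int | None) -> int:
    # Mirrors the rule in get_composed_model_initializer: if a pipeline-parallel
    # dimension exists, shift the base seed by the PP rank; otherwise keep it as-is.
    return seed if pp_rank is None else seed + pp_rank


def draw_init_weights(seed: int, shape: tuple[int, int] = (4, 4), mean: float = 0.0, std: float = 0.02) -> torch.Tensor:
    # Stand-in for NamedParameterwiseNormalInitialization: seed the RNG, then sample N(mean, std).
    torch.manual_seed(seed)
    return torch.empty(shape).normal_(mean=mean, std=std)


if __name__ == "__main__":
    base_seed = 42
    w_stage0 = draw_init_weights(offset_seed_by_pp_rank(base_seed, pp_rank=0))
    w_stage1 = draw_init_weights(offset_seed_by_pp_rank(base_seed, pp_rank=1))
    w_stage0_replica = draw_init_weights(offset_seed_by_pp_rank(base_seed, pp_rank=0))
    # Distinct PP stages draw different weights; replicas of the same stage (e.g. DP ranks)
    # draw identical weights, matching what the test in this series asserts.
    assert not torch.equal(w_stage0, w_stage1)
    assert torch.equal(w_stage0, w_stage0_replica)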