diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py
index ac3dca96b..5dc7986b2 100644
--- a/src/modalities/models/model.py
+++ b/src/modalities/models/model.py
@@ -26,7 +26,7 @@ class ActivationType(str, Enum):
 class NNModel(nn.Module):
     """NNModel class to define a base model."""
 
-    def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
+    def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
         """
         Initializes an NNModel object.
 
diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py
index 190311cb6..1789011f1 100644
--- a/src/modalities/nn/model_initialization/composed_initialization.py
+++ b/src/modalities/nn/model_initialization/composed_initialization.py
@@ -2,9 +2,10 @@
 
 import torch.nn as nn
 from pydantic import BaseModel, ConfigDict, Field, model_validator
+from torch.distributed.device_mesh import DeviceMesh
 from typing_extensions import Annotated
 
-from modalities.config.pydantic_if_types import PydanticModelInitializationIFType
+from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType
 from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
 from modalities.nn.model_initialization.initialization_routines import InitializationRoutines
 from modalities.nn.model_initialization.parameter_name_filters import (
@@ -12,6 +13,7 @@
     SupportWeightInitModels,
     WeightInitTypes,
 )
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method
 
 
 class ModelInitializerWrapperConfig(BaseModel):
@@ -30,6 +32,8 @@ class ComposedModelInitializationConfig(BaseModel):
     std: Annotated[float, Field(strict=True, ge=0.0)] | str  # can be float or "auto"
     hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
     num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
+    seed: Optional[int] = None
+    device_mesh: Optional[PydanticDeviceMeshIFType] = None
 
     # avoid warning about protected namespace 'model_', see
     # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
@@ -100,6 +104,8 @@ def get_composed_model_initializer(
     std: float | str,
     hidden_dim: Optional[int] = None,
     num_layers: int = None,
+    device_mesh: Optional[DeviceMesh] = None,
+    seed: Optional[int] = None,
 ) -> ModelInitializationIF:
     """This initialization allows to intialize a model with plain, scaled or scaled_embed initialization.
     Note that plain initialization is always performed in the beginning. In case of scaled_embed,
@@ -114,16 +120,28 @@ def get_composed_model_initializer(
             Defaults to None.
         num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only).
             Defaults to None.
+        device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization.
+        seed (Optional[int], optional): Seed for random initialization. Defaults to None.
 
     Returns:
         ModelInitializationIF: The Weight Initializer performing the initialization as specified.
     """
+    # Set different random seed for each PP rank to ensure diversity
+    if seed is not None and has_parallelism_method(
+        device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
+    ):
+        seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)
+
     model_initializers = []
 
     # plain
     plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN]
     plain_init = InitializationRoutines.get_plain_initialization(
-        mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes
+        mean=mean,
+        std=std,
+        hidden_dim=hidden_dim,
+        parameter_name_regexes=plain_parameter_name_regexes,
+        seed=seed,
     )
     working_std = plain_init.std
     model_initializers.append(plain_init)
@@ -136,6 +154,7 @@ def get_composed_model_initializer(
             std=working_std,
             num_layers=num_layers,
             parameter_name_regexes=scaled_parameter_name_regexes,
+            seed=seed,
         )
         model_initializers.append(scaled_init)
 
@@ -143,7 +162,9 @@ def get_composed_model_initializer(
         # scaled embed
         scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED]
         scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization(
-            mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes
+            mean=mean,
+            parameter_name_regexes=scaled_embed_parameter_name_regexes,
+            seed=seed,
         )
         model_initializers.append(scaled_embed_init)
 
diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py
index 5b4515875..36953d646 100644
--- a/src/modalities/nn/model_initialization/initialization_routines.py
+++ b/src/modalities/nn/model_initialization/initialization_routines.py
@@ -2,6 +2,7 @@
 import re
 from typing import Annotated, Optional
 
+import torch
 import torch.nn as nn
 from pydantic import BaseModel, Field, model_validator
 
@@ -59,7 +60,11 @@ def initialize_in_place(self, model: nn.Module):
 class InitializationRoutines:
     @staticmethod
     def get_plain_initialization(
-        mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None
+        mean: float,
+        std: float | str,
+        parameter_name_regexes: list[str],
+        hidden_dim: Optional[int] = None,
+        seed: Optional[int] = None,
     ) -> NamedParameterwiseNormalInitialization:
         """Initializes the weights of a model by sampling from a normal distribution.
         NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers.
@@ -70,8 +75,11 @@ def get_plain_initialization(
             std (float): standard deviation of the normal distribution. If set to "auto", appropiate value
                 selected as per plain initialization described in https://arxiv.org/abs/2312.16903
             hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None.
+            parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
+                should be applied
+            seed (Optional[int]): Random seed for initialization. Defaults to None.
         """
-
+        InitializationRoutines._set_seed(seed)
         # auto: choose std automatically
         if std == "auto":
             if hidden_dim is None:
@@ -86,7 +94,7 @@ def get_plain_initialization(
 
     @staticmethod
     def get_scaled_initialization(
-        mean: float, std: float, num_layers: int, parameter_name_regexes: list[str]
+        mean: float, std: float, num_layers: int, parameter_name_regexes: list[str], seed: Optional[int] = None
     ) -> ModelInitializationIF:
         """Implementation of scaled weight initialization.
         As defined in https://arxiv.org/abs/2312.16903
@@ -96,10 +104,12 @@ def get_scaled_initialization(
             num_layers (int): Number of layers in the model which we use to downscale std with
             parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
                 should be applied
+            seed (Optional[int]): Random seed for initialization. Defaults to None.
 
         Returns:
             WeightInitializationIF: Weight initialization object
         """
+        InitializationRoutines._set_seed(seed)
         # see https://arxiv.org/abs/2312.16903
         scaled_std = std / math.sqrt(2 * num_layers)
 
@@ -109,7 +119,9 @@ def get_scaled_initialization(
         return initialization
 
     @staticmethod
-    def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF:
+    def get_scaled_embed_initialization(
+        mean: float, parameter_name_regexes: list[str], seed: Optional[int] = None
+    ) -> ModelInitializationIF:
         """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903
         We fix the standard deviation to sqrt(0.4).
 
@@ -117,12 +129,19 @@ def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[st
             mean (float): Mean of the normal distribution
             parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization
                 should be applied Defaults to None.
+            seed (Optional[int]): Random seed for initialization. Defaults to None.
 
         Returns:
             WeightInitializationIF: Weight initialization object
         """
+        InitializationRoutines._set_seed(seed)
         std = math.sqrt(0.4)
         initialization = NamedParameterwiseNormalInitialization(
             mean=mean, std=std, parameter_name_regexes=parameter_name_regexes
         )
         return initialization
+
+    @staticmethod
+    def _set_seed(seed: Optional[int]):
+        if seed is not None:
+            torch.manual_seed(seed)
diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml
index fb8ee5f7d..8fe1d5472 100644
--- a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml
+++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml
@@ -129,7 +129,11 @@ initialized_model:
         weight_init_type: scaled
         mean: 0.0
         std: 0.02
+        seed: 42
         num_layers: ${model_raw.config.n_layer}
+        device_mesh:
+          instance_key: device_mesh
+          pass_type: BY_REFERENCE
 
 model_raw:
   component_key: model
diff --git a/tests/fsdp2_parallelization/test_parallel_seed_initialization.py b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py
new file mode 100644
index 000000000..b9bb2f7ca
--- /dev/null
+++ b/tests/fsdp2_parallelization/test_parallel_seed_initialization.py
@@ -0,0 +1,159 @@
+import logging
+import multiprocessing as py_mp
+import os
+import traceback
+from pathlib import Path
+from typing import Any
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import yaml
+from pydantic import BaseModel
+from torch.distributed._tensor.placement_types import Replicate
+
+from modalities.__main__ import Main
+from modalities.config.config import ProcessGroupBackendType
+from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank
+from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
+from tests.utility import monitor_child_processes
+
+working_dir = Path(os.path.dirname(__file__))
+tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp"
+working_dir = working_dir / "configs"
+
+
+@pytest.mark.skipif(
+    torch.cuda.device_count() < 8,
+    reason="This e2e test requires 8 GPUs.",
+)
+class TestParallelSeedInitialization:
+    WORLD_SIZE = 8
+    RDVZ_PORT = 24574
+
+    def test_parameters_follow_parallelism(self, tmp_path: Path):
+        manager = py_mp.Manager()
+        error_queue = manager.Queue()
+        proc_ctx = mp.spawn(
+            self._seed_distribution_impl_wrapper,
+            args=(self.WORLD_SIZE, tmp_path, error_queue),
+            nprocs=self.WORLD_SIZE,
+            join=False,
+        )
+        monitor_child_processes(manager, error_queue, proc_ctx)
+
+    def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_path: Path, error_queue: Any):
+        with MultiProcessingCudaEnv(
+            process_group_backend=ProcessGroupBackendType.nccl,
+            global_rank=process_id,
+            local_rank=process_id,
+            world_size=world_size,
+            rdvz_port=TestParallelSeedInitialization.RDVZ_PORT,
+        ):
+            try:
+                self._seed_distribution_impl(world_size=world_size, tmp_path=tmp_path)
+            except Exception as exc:
+                tb = traceback.format_exc()
+                logging.error(f"Process {process_id} (seed distribution test) encountered an error:\n{exc}")
+                logging.error(tb)
+                try:
+                    error_queue.put((process_id, tb))
+                except Exception:
+                    logging.error("Failed to put exception info into error queue (seed distribution test).")
+                os._exit(1)
+
+    def _seed_distribution_impl(self, world_size: int, tmp_path: Path):
+        # initialize components
+        class ComponentsInstantiationModel(BaseModel):
+            fsdp_model: PydanticFSDP2ModuleType
+            device_mesh: PydanticDeviceMeshIFType
+
+        config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path)
+        main_obj = Main(config_file_path)
+        components = main_obj.build_components(components_model_type=ComponentsInstantiationModel)
+        model = components.fsdp_model
+        device_mesh = components.device_mesh
+        # for each pp stage get first transformer block's MLP weight parameter shards and full tensor
+        block_key = next(iter(model.transformer.h.keys()))
+        block = model.transformer.h[block_key]
+        placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape)
+        full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu()
+        payload = {
+            "tensor_full": full_weight,
+            "tensor_shard": block.mlp.W.weight.to_local().cpu(),
+            "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP),
+            "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP),
+            "dp_shard_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.DP_SHARD),
+            "block_key": block_key,
+        }
+
+        gather_list: list[dict[str, Any]] | None = [None] * world_size if dist.get_rank() == 0 else None
+        dist.gather_object(payload, gather_list, dst=0)
+
+        if dist.get_rank() == 0:
+            assert gather_list is not None
+            TestParallelSeedInitialization._assert_parameter_distribution(gather_list)
+        dist.barrier()
+
+    @staticmethod
+    def _assert_parameter_distribution(records: list[dict[str, Any]]):
+        combos: dict[tuple[int, int], list[dict[str, Any]]] = {}
+        for record in records:
+            key = (record["pp_rank"], record["tp_rank"])
+            combos.setdefault(key, []).append(record)
+
+        expected_combo_count = 4
+        assert (
+            len(combos) == expected_combo_count
+        ), f"Expected {expected_combo_count} PP/TP combinations, got {len(combos)}"
+
+        combo_tensors: dict[tuple[int, int], torch.Tensor] = {}
+        for (pp_rank, tp_rank), entries in combos.items():
+            # check that full tensors are the same across data parallel processes
+            reference = entries[0]["tensor_full"]
+            seen_dp_ranks: set[int] = set()
+            for entry in entries:
+                dp_rank = entry["dp_shard_rank"]
+                assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}"
+                seen_dp_ranks.add(dp_rank)
+                assert torch.equal(reference, entry["tensor_full"]), (
+                    "Tensors within the same TP/PP combo must be identical across DP ranks; "
+                    f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})"
+                )
+            # concatenate all shards for this pp/tp combo
+            shards = sorted(entries, key=lambda e: e["dp_shard_rank"])
+            combo_tensors[(pp_rank, tp_rank)] = torch.cat(
+                [e["tensor_shard"] for e in shards],
+                dim=0,
+            )
+        # check that tensor shards differ across different pp/tp combos
+        combo_items = list(combo_tensors.items())
+        for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items):
+            for other_key, other_tensor in combo_items[idx + 1 :]:
+                tensors_equal = torch.equal(base_tensor, other_tensor)
+                assert not tensors_equal, (
+                    "Distinct TP/PP combinations should initialize with different weights; "
+                    f"found match between (PP={pp_rank}, TP={tp_rank}) and (PP={other_key[0]}, TP={other_key[1]})"
+                )
+
+    def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degree: int, tmp_path: Path) -> Path:
+        temp_file_path = tmp_path / "pp_tp_sharding_config.yaml"
+        working_dir = Path(os.path.dirname(__file__))
+        config_file_path = (
+            working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml"
+        )
+
+        with open(config_file_path, "r") as file:
+            config_string = file.read()
+        config_dict = yaml.safe_load(config_string)
+        config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = dp_degree
+        config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree
+        config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree
+
+        # save to temporary file
+        with open(temp_file_path, "w") as file:
+            yaml.dump(config_dict, file)
+
+        return temp_file_path
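
Note on the seeding behavior introduced by this patch: when the configured device mesh has a pipeline-parallel dimension, get_composed_model_initializer offsets the configured seed by the PP rank, so each pipeline stage draws different initial weights, while all ranks belonging to the same stage seed identically and their TP/DP shards stay consistent. The sketch below is illustrative only and not part of the patch; seeded_normal_init_, base_seed, and pp_rank are hypothetical stand-ins for the patch's InitializationRoutines plumbing, the configured seed, and get_parallel_rank(device_mesh, ParallelismDegrees.PP).

    # Minimal sketch (assumed names, not modalities APIs) of per-stage seeded init.
    import torch
    import torch.nn as nn

    def seeded_normal_init_(module: nn.Module, base_seed: int, pp_rank: int,
                            mean: float = 0.0, std: float = 0.02) -> None:
        # Offsetting the seed by the pipeline-parallel rank makes each pipeline
        # stage draw different initial weights, while ranks that share the same
        # base_seed and pp_rank initialize identically.
        torch.manual_seed(base_seed + pp_rank)
        for param in module.parameters():
            nn.init.normal_(param, mean=mean, std=std)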