2 changes: 1 addition & 1 deletion src/modalities/models/model.py
@@ -26,7 +26,7 @@ class ActivationType(str, Enum):
class NNModel(nn.Module):
"""NNModel class to define a base model."""

def __init__(self, seed: int = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
def __init__(self, seed: Optional[int] = None, weight_decay_groups: Optional[WeightDecayGroups] = None):
"""
Initializes an NNModel object.

27 changes: 24 additions & 3 deletions src/modalities/nn/model_initialization/composed_initialization.py
@@ -2,16 +2,18 @@

import torch.nn as nn
from pydantic import BaseModel, ConfigDict, Field, model_validator
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import Annotated

from modalities.config.pydantic_if_types import PydanticModelInitializationIFType
from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticModelInitializationIFType
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.nn.model_initialization.initialization_routines import InitializationRoutines
from modalities.nn.model_initialization.parameter_name_filters import (
NAMED_PARAMETER_INIT_GROUPS,
SupportWeightInitModels,
WeightInitTypes,
)
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank, has_parallelism_method


class ModelInitializerWrapperConfig(BaseModel):
@@ -30,6 +32,8 @@ class ComposedModelInitializationConfig(BaseModel):
std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto"
hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
seed: Optional[int] = None
device_mesh: Optional[PydanticDeviceMeshIFType] = None

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
@@ -100,6 +104,8 @@ def get_composed_model_initializer(
std: float | str,
hidden_dim: Optional[int] = None,
num_layers: int = None,
device_mesh: Optional[DeviceMesh] = None,
seed: Optional[int] = None,
) -> ModelInitializationIF:
"""This initialization allows to intialize a model with plain, scaled or scaled_embed initialization.
Note that plain initialization is always performed in the beginning. In case of scaled_embed,
@@ -114,16 +120,28 @@
Defaults to None.
num_layers (int, optional): Number of layers in the model (required for scaled and scaled_embed only).
Defaults to None.
device_mesh (Optional[DeviceMesh], optional): Device mesh used for parallelization. Defaults to None.
seed (Optional[int], optional): Seed for random initialization. Defaults to None.

Returns:
ModelInitializationIF: The Weight Initializer performing the initialization as specified.
"""
# Set different random seed for each PP rank to ensure diversity
if seed is not None and has_parallelism_method(
device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP
):
seed += get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP)

model_initializers = []

# plain
plain_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.PLAIN]
plain_init = InitializationRoutines.get_plain_initialization(
mean=mean, std=std, hidden_dim=hidden_dim, parameter_name_regexes=plain_parameter_name_regexes
mean=mean,
std=std,
hidden_dim=hidden_dim,
parameter_name_regexes=plain_parameter_name_regexes,
seed=seed,
)
working_std = plain_init.std
model_initializers.append(plain_init)
@@ -136,14 +154,17 @@
std=working_std,
num_layers=num_layers,
parameter_name_regexes=scaled_parameter_name_regexes,
seed=seed,
)
model_initializers.append(scaled_init)

if weight_init_type == WeightInitTypes.SCALED_EMBED:
# scaled embed
scaled_embed_parameter_name_regexes = NAMED_PARAMETER_INIT_GROUPS[model_type][WeightInitTypes.SCALED_EMBED]
scaled_embed_init = InitializationRoutines.get_scaled_embed_initialization(
mean=mean, parameter_name_regexes=scaled_embed_parameter_name_regexes
mean=mean,
parameter_name_regexes=scaled_embed_parameter_name_regexes,
seed=seed,
)
model_initializers.append(scaled_embed_init)

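For context on the seed-offset hunk above: each pipeline-parallel stage shifts the base seed by its PP rank, while data- and tensor-parallel replicas within a stage keep the stage's seed. A minimal, self-contained sketch of that rule, with a plain integer standing in for the real `DeviceMesh` helpers (`effective_seed` is a hypothetical name used only for illustration):

```python
# Illustration only: mirrors the offset logic added in get_composed_model_initializer,
# with a plain pp_rank integer standing in for
# get_parallel_rank(device_mesh=..., parallelism_method=ParallelismDegrees.PP).
from typing import Optional


def effective_seed(base_seed: Optional[int], pp_rank: Optional[int]) -> Optional[int]:
    """Seed a rank should use for weight initialization (hypothetical helper)."""
    if base_seed is None:
        return None  # no seeding requested -> RNG state is left untouched
    if pp_rank is None:
        return base_seed  # no pipeline parallelism -> all ranks share the base seed
    return base_seed + pp_rank  # each PP stage draws from its own seed


# With seed=42 and pipeline_parallel_degree=2: stage 0 initializes with 42, stage 1 with 43;
# DP/TP replicas inside a stage reuse the stage's seed and therefore sample identical weights.
assert effective_seed(42, 0) == 42
assert effective_seed(42, 1) == 43
assert effective_seed(None, 1) is None
```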
27 changes: 23 additions & 4 deletions src/modalities/nn/model_initialization/initialization_routines.py
@@ -2,6 +2,7 @@
import re
from typing import Annotated, Optional

import torch
import torch.nn as nn
from pydantic import BaseModel, Field, model_validator

@@ -59,7 +60,11 @@ def initialize_in_place(self, model: nn.Module):
class InitializationRoutines:
@staticmethod
def get_plain_initialization(
mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None
mean: float,
std: float | str,
parameter_name_regexes: list[str],
hidden_dim: Optional[int] = None,
seed: Optional[int] = None,
) -> NamedParameterwiseNormalInitialization:
"""Initializes the weights of a model by sampling from a normal distribution.
NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers.
@@ -70,8 +75,11 @@ def get_plain_initialization(
std (float | str): standard deviation of the normal distribution. If set to "auto", an appropriate
value is selected as per the plain initialization described in https://arxiv.org/abs/2312.16903
hidden_dim (Optional[int]): hidden dimension of the attention layer. Defaults to None.
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
should be applied
seed (Optional[int]): Random seed for initialization. Defaults to None.
"""

InitializationRoutines._set_seed(seed)
# auto: choose std automatically
if std == "auto":
if hidden_dim is None:
@@ -86,7 +94,7 @@

@staticmethod
def get_scaled_initialization(
mean: float, std: float, num_layers: int, parameter_name_regexes: list[str]
mean: float, std: float, num_layers: int, parameter_name_regexes: list[str], seed: Optional[int] = None
) -> ModelInitializationIF:
"""Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903

@@ -96,10 +104,12 @@
num_layers (int): Number of layers in the model, used to downscale the std
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
should be applied
seed (Optional[int]): Random seed for initialization. Defaults to None.

Returns:
ModelInitializationIF: Weight initialization object
"""
InitializationRoutines._set_seed(seed)
# see https://arxiv.org/abs/2312.16903
scaled_std = std / math.sqrt(2 * num_layers)

@@ -109,20 +119,29 @@
return initialization

@staticmethod
def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF:
def get_scaled_embed_initialization(
mean: float, parameter_name_regexes: list[str], seed: Optional[int] = None
) -> ModelInitializationIF:
"""Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903
We fix the standard deviation to sqrt(0.4).

Args:
mean (float): Mean of the normal distribution
parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization
should be applied
seed (Optional[int]): Random seed for initialization. Defaults to None.

Returns:
ModelInitializationIF: Weight initialization object
"""
InitializationRoutines._set_seed(seed)
std = math.sqrt(0.4)
initialization = NamedParameterwiseNormalInitialization(
mean=mean, std=std, parameter_name_regexes=parameter_name_regexes
)
return initialization

@staticmethod
def _set_seed(seed: Optional[int]):
if seed is not None:
torch.manual_seed(seed)
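Each `get_*_initialization` routine now reseeds the global RNG via `_set_seed` before the normal-distribution init, so equal seeds yield identical weights and offset seeds yield different ones. A small single-process sketch of that behaviour, assuming an arbitrary layer shape and the `std=0.02` from the example config (`init_linear` is a hypothetical helper, not part of the library):

```python
import torch
import torch.nn as nn


def init_linear(seed: int) -> torch.Tensor:
    # Same pattern as above: reseed first, then draw weights from a normal distribution.
    torch.manual_seed(seed)
    layer = nn.Linear(8, 8, bias=False)
    nn.init.normal_(layer.weight, mean=0.0, std=0.02)
    return layer.weight.detach().clone()


# Equal seeds -> identical weights (what DP/TP replicas within one PP stage see).
assert torch.equal(init_linear(42), init_linear(42))
# Offset seed (base seed + PP rank) -> different weights across PP stages.
assert not torch.equal(init_linear(42), init_linear(43))
# The scaled variant above additionally divides std by sqrt(2 * num_layers),
# and the scaled_embed variant fixes std to sqrt(0.4).
```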
@@ -129,7 +129,11 @@ initialized_model:
weight_init_type: scaled
mean: 0.0
std: 0.02
seed: 42
num_layers: ${model_raw.config.n_layer}
device_mesh:
instance_key: device_mesh
pass_type: BY_REFERENCE

model_raw:
component_key: model
159 changes: 159 additions & 0 deletions tests/fsdp2_parallelization/test_parallel_seed_initialization.py
@@ -0,0 +1,159 @@
import logging
import multiprocessing as py_mp
import os
import traceback
from pathlib import Path
from typing import Any

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yaml
from pydantic import BaseModel
from torch.distributed._tensor.placement_types import Replicate

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank
from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
from tests.utility import monitor_child_processes

working_dir = Path(os.path.dirname(__file__))
tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp"
working_dir = working_dir / "configs"


@pytest.mark.skipif(
torch.cuda.device_count() < 8,
reason="This e2e test requires 8 GPUs.",
)
class TestParallelSeedInitialization:
WORLD_SIZE = 8
RDVZ_PORT = 24574

def test_parameters_follow_parallelism(self, tmp_path: Path):
manager = py_mp.Manager()
error_queue = manager.Queue()
proc_ctx = mp.spawn(
self._seed_distribution_impl_wrapper,
args=(self.WORLD_SIZE, tmp_path, error_queue),
nprocs=self.WORLD_SIZE,
join=False,
)
monitor_child_processes(manager, error_queue, proc_ctx)

def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_path: Path, error_queue: Any):
with MultiProcessingCudaEnv(
process_group_backend=ProcessGroupBackendType.nccl,
global_rank=process_id,
local_rank=process_id,
world_size=world_size,
rdvz_port=TestParallelSeedInitialization.RDVZ_PORT,
):
try:
self._seed_distribution_impl(world_size=world_size, tmp_path=tmp_path)
except Exception as exc:
tb = traceback.format_exc()
logging.error(f"Process {process_id} (seed distribution test) encountered an error:\n{exc}")
logging.error(tb)
try:
error_queue.put((process_id, tb))
except Exception:
logging.error("Failed to put exception info into error queue (seed distribution test).")
os._exit(1)

def _seed_distribution_impl(self, world_size: int, tmp_path: Path):
# initialize components
class ComponentsInstantiationModel(BaseModel):
fsdp_model: PydanticFSDP2ModuleType
device_mesh: PydanticDeviceMeshIFType

config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path)
main_obj = Main(config_file_path)
components = main_obj.build_components(components_model_type=ComponentsInstantiationModel)
model = components.fsdp_model
device_mesh = components.device_mesh
# for each pp stage get first transformer block's MLP weight parameter shards and full tensor
block_key = next(iter(model.transformer.h.keys()))
block = model.transformer.h[block_key]
placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape)
full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu()
payload = {
"tensor_full": full_weight,
"tensor_shard": block.mlp.W.weight.to_local().cpu(),
"tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP),
"pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP),
"dp_shard_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.DP_SHARD),
"block_key": block_key,
}

gather_list: list[dict[str, Any]] | None = [None] * world_size if dist.get_rank() == 0 else None
dist.gather_object(payload, gather_list, dst=0)

if dist.get_rank() == 0:
assert gather_list is not None
TestParallelSeedInitialization._assert_parameter_distribution(gather_list)
dist.barrier()

@staticmethod
def _assert_parameter_distribution(records: list[dict[str, Any]]):
combos: dict[tuple[int, int], list[dict[str, Any]]] = {}
for record in records:
key = (record["pp_rank"], record["tp_rank"])
combos.setdefault(key, []).append(record)

expected_combo_count = 4
assert (
len(combos) == expected_combo_count
), f"Expected {expected_combo_count} PP/TP combinations, got {len(combos)}"

combo_tensors: dict[tuple[int, int], torch.Tensor] = {}
for (pp_rank, tp_rank), entries in combos.items():
# check that full tensors are the same across data parallel processes
reference = entries[0]["tensor_full"]
seen_dp_ranks: set[int] = set()
for entry in entries:
dp_rank = entry["dp_shard_rank"]
assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}"
seen_dp_ranks.add(dp_rank)
assert torch.equal(reference, entry["tensor_full"]), (
"Tensors within the same TP/PP combo must be identical across DP ranks; "
f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})"
)
# concatenate all shards for this pp/tp combo
shards = sorted(entries, key=lambda e: e["dp_shard_rank"])
combo_tensors[(pp_rank, tp_rank)] = torch.cat(
[e["tensor_shard"] for e in shards],
dim=0,
)
# check that tensor shards differ across different pp/tp combos
combo_items = list(combo_tensors.items())
for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items):
for other_key, other_tensor in combo_items[idx + 1 :]:
tensors_equal = torch.equal(base_tensor, other_tensor)
assert not tensors_equal, (
"Distinct TP/PP combinations should initialize with different weights; "
f"found match between (PP={pp_rank}, TP={tp_rank}) and (PP={other_key[0]}, TP={other_key[1]})"
)

def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degree: int, tmp_path: Path) -> Path:
temp_file_path = tmp_path / "pp_tp_sharding_config.yaml"
working_dir = Path(os.path.dirname(__file__))
config_file_path = (
working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml"
)

with open(config_file_path, "r") as file:
config_string = file.read()
config_dict = yaml.safe_load(config_string)
config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = dp_degree
config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree
config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree

# save to temporary file
with open(temp_file_path, "w") as file:
yaml.dump(config_dict, file)

return temp_file_path
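A rough single-process illustration of the invariants the test checks: full tensors agree across DP ranks within a (PP, TP) combination, while different PP stages were seeded differently and therefore differ. Toy tensors only; the real test compares DTensor shards and redistributed full tensors across an 8-GPU mesh, and `stage_weights` is a hypothetical stand-in:

```python
import torch


def stage_weights(pp_rank: int) -> torch.Tensor:
    # Toy stand-in: every rank of a PP stage re-creates its weights from seed 42 + pp_rank,
    # mimicking the per-stage seed offset introduced in this PR.
    torch.manual_seed(42 + pp_rank)
    return torch.randn(4, 4)


records = [
    {"pp_rank": pp, "tp_rank": tp, "dp_shard_rank": dp, "tensor_full": stage_weights(pp)}
    for pp in (0, 1)
    for tp in (0, 1)
    for dp in (0, 1)
]

combos: dict[tuple[int, int], list[dict]] = {}
for rec in records:
    combos.setdefault((rec["pp_rank"], rec["tp_rank"]), []).append(rec)

# Within each (PP, TP) combo, the full tensors are identical across DP ranks ...
for entries in combos.values():
    assert all(torch.equal(entries[0]["tensor_full"], e["tensor_full"]) for e in entries)

# ... while stages with different PP ranks were seeded differently and thus differ.
assert not torch.equal(combos[(0, 0)][0]["tensor_full"], combos[(1, 0)][0]["tensor_full"])
```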