 import torch.multiprocessing as mp
 import yaml
 from pydantic import BaseModel
+from torch.distributed._tensor.placement_types import Replicate
 
 from modalities.__main__ import Main
 from modalities.batch import EvaluationResultBatch
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
 from modalities.logging_broker.messages import Message
-from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_parallel_rank
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank
 from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 from tests.utility import monitor_child_processes
 
@@ -67,17 +68,6 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_
             os._exit(1)
 
     def _seed_distribution_impl(self, world_size: int, tmp_path: Path):
-        device_mesh = get_device_mesh(
-            device_type="cuda",
-            data_parallel_replicate_degree=2,
-            data_parallel_shard_degree=1,
-            tensor_parallel_degree=2,
-            pipeline_parallel_degree=2,
-            context_parallel_degree=1,
-            enable_loss_parallel=False,
-            world_size=world_size,
-        )
-
         # initialize components
         class ComponentsInstantiationModel(BaseModel):
             fsdp_model: PydanticFSDP2ModuleType
@@ -88,10 +78,13 @@ class ComponentsInstantiationModel(BaseModel):
         components = main_obj.build_components(components_model_type=ComponentsInstantiationModel)
         model = components.fsdp_model
        device_mesh = components.device_mesh
-        # get first transformer block's MLP weight parameter shards
+        # for each pp stage get first transformer block's MLP weight parameter shards and full tensor
         block_key = next(iter(model.transformer.h.keys()))
         block = model.transformer.h[block_key]
+        placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape)
+        full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu()
         payload = {
+            "tensor_full": full_weight,
             "tensor_shard": block.mlp.W.weight.to_local().cpu(),
             "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP),
             "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP),
@@ -121,12 +114,24 @@ def _assert_parameter_distribution(records: list[dict[str, Any]]):
 
         combo_tensors: dict[tuple[int, int], torch.Tensor] = {}
         for (pp_rank, tp_rank), entries in combos.items():
+            # check that full tensors are the same across data parallel processes
+            reference = entries[0]["tensor_full"]
+            seen_dp_ranks: set[int] = set()
+            for entry in entries:
+                dp_rank = entry["dp_shard_rank"]
+                assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}"
+                seen_dp_ranks.add(dp_rank)
+                assert torch.equal(reference, entry["tensor_full"]), (
+                    "Tensors within the same TP/PP combo must be identical across DP ranks; "
+                    f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})"
+                )
+            # concatenate all shards for this pp/tp combo
             shards = sorted(entries, key=lambda e: e["dp_shard_rank"])
             combo_tensors[(pp_rank, tp_rank)] = torch.cat(
                 [e["tensor_shard"] for e in shards],
                 dim=0,
             )
-
+        # check that tensor shards differ across different pp/tp combos
         combo_items = list(combo_tensors.items())
         for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items):
             for other_key, other_tensor in combo_items[idx + 1 :]:
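
For context, here is a minimal standalone sketch of the DTensor gather pattern that the added full_weight lines rely on. The helper name gather_full_tensor is hypothetical (not part of the diff); it assumes torch.distributed and the device mesh are already initialized and that the input is a DTensor.

# Minimal sketch: materialize the full tensor from a sharded DTensor by
# replicating it over every mesh dimension, mirroring the full_weight logic above.
import torch
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Replicate


def gather_full_tensor(dtensor: DTensor) -> torch.Tensor:
    # One Replicate() placement per mesh dimension replaces any Shard placements,
    # so to_local() afterwards returns the complete, unsharded tensor on every rank.
    placements = [Replicate()] * dtensor.device_mesh.ndim
    return dtensor.redistribute(placements=placements).to_local().cpu()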