Commit 00a595b

test: Check for equal parameters across data parallel processes
1 parent 62a1743 commit 00a595b

1 file changed: +19 −14 lines changed

tests/fsdp2_parallelization/test_parallel_seed_initialization.py

Lines changed: 19 additions & 14 deletions
@@ -12,13 +12,14 @@
 import torch.multiprocessing as mp
 import yaml
 from pydantic import BaseModel
+from torch.distributed._tensor.placement_types import Replicate
 
 from modalities.__main__ import Main
 from modalities.batch import EvaluationResultBatch
 from modalities.config.config import ProcessGroupBackendType
 from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
 from modalities.logging_broker.messages import Message
-from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_device_mesh, get_parallel_rank
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank
 from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 from tests.utility import monitor_child_processes
 
@@ -67,17 +68,6 @@ def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_
             os._exit(1)
 
     def _seed_distribution_impl(self, world_size: int, tmp_path: Path):
-        device_mesh = get_device_mesh(
-            device_type="cuda",
-            data_parallel_replicate_degree=2,
-            data_parallel_shard_degree=1,
-            tensor_parallel_degree=2,
-            pipeline_parallel_degree=2,
-            context_parallel_degree=1,
-            enable_loss_parallel=False,
-            world_size=world_size,
-        )
-
         # initialize components
         class ComponentsInstantiationModel(BaseModel):
             fsdp_model: PydanticFSDP2ModuleType
@@ -88,10 +78,13 @@ class ComponentsInstantiationModel(BaseModel):
         components = main_obj.build_components(components_model_type=ComponentsInstantiationModel)
         model = components.fsdp_model
         device_mesh = components.device_mesh
-        # get first transformer block's MLP weight parameter shards
+        # for each pp stage get first transformer block's MLP weight parameter shards and full tensor
         block_key = next(iter(model.transformer.h.keys()))
         block = model.transformer.h[block_key]
+        placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape)
+        full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu()
         payload = {
+            "tensor_full": full_weight,
             "tensor_shard": block.mlp.W.weight.to_local().cpu(),
             "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP),
             "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP),
@@ -121,12 +114,24 @@ def _assert_parameter_distribution(records: list[dict[str, Any]]):
 
         combo_tensors: dict[tuple[int, int], torch.Tensor] = {}
         for (pp_rank, tp_rank), entries in combos.items():
+            # check that full tensors are the same across data parallel processes
+            reference = entries[0]["tensor_full"]
+            seen_dp_ranks: set[int] = set()
+            for entry in entries:
+                dp_rank = entry["dp_shard_rank"]
+                assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}"
+                seen_dp_ranks.add(dp_rank)
+                assert torch.equal(reference, entry["tensor_full"]), (
+                    "Tensors within the same TP/PP combo must be identical across DP ranks; "
+                    f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})"
+                )
+            # concatenate all shards for this pp/tp combo
             shards = sorted(entries, key=lambda e: e["dp_shard_rank"])
             combo_tensors[(pp_rank, tp_rank)] = torch.cat(
                 [e["tensor_shard"] for e in shards],
                 dim=0,
             )
-
+        # check that tensor shards differ across different pp/tp combos
         combo_items = list(combo_tensors.items())
         for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items):
             for other_key, other_tensor in combo_items[idx + 1 :]:
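
The new assertions rely on a DTensor redistribution pattern: replicating a sharded parameter on every mesh dimension materializes the full tensor on each rank, which can then be compared across data parallel processes. The snippet below is a minimal, self-contained sketch of that pattern, not part of the commit; the gloo/CPU backend, the 1D mesh, the tensor shape, and the script name in the torchrun comment are illustrative assumptions chosen so the sketch can run without GPUs.

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.placement_types import Replicate
from torch.distributed.device_mesh import init_device_mesh

# illustrative setup: gloo/CPU backend, launched e.g. via
#   torchrun --nproc_per_node=2 dtensor_replicate_sketch.py   (hypothetical script name)
dist.init_process_group(backend="gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# identical source tensor on every rank (the test relies on seeded model init instead)
torch.manual_seed(0)
weight = torch.randn(8, 4)

# shard the weight along dim 0 across the 1D mesh; the test's mesh is
# multi-dimensional (DP replicate / DP shard / TP / PP), but the pattern is the same
dweight = distribute_tensor(weight, mesh, placements=[Shard(0)])

# the local shard held by this rank (what the test stores as "tensor_shard")
local_shard = dweight.to_local().cpu()

# replicating on every mesh dimension gathers the full, unsharded tensor on each
# rank (what the test stores as "tensor_full")
placements = [Replicate()] * len(dweight.device_mesh.mesh.shape)
full_weight = dweight.redistribute(placements=placements).to_local().cpu()

# every rank now holds the complete weight and can compare it across processes
assert torch.equal(full_weight, weight)
dist.destroy_process_group()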
