13 changes: 12 additions & 1 deletion src/diffusers/models/_modeling_parallel.py
@@ -60,6 +60,16 @@ class ContextParallelConfig:
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
            Method to use for rotating key/value states across devices in ring attention. Currently, only
            `"allgather"` is supported.
        ulysses_anything (`bool`, *optional*, defaults to `False`):
            Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts
            that are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than
            1 and `ring_degree` must be 1.
        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
            A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
            creating a new one. This is useful when combining context parallelism with other parallelism strategies
            (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
            "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).

"""

@@ -68,6 +78,7 @@ class ContextParallelConfig:
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"
    mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    # Whether to enable Ulysses Anything attention, which supports arbitrary
    # sequence lengths and head counts.
    ulysses_anything: bool = False
@@ -124,7 +135,7 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)

-        self._flattened_mesh = self._mesh._flatten()
+        self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
        self._ring_mesh = self._mesh["ring"]
        self._ulysses_mesh = self._mesh["ulysses"]
        self._ring_local_rank = self._ring_mesh.get_local_rank()
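For context, a minimal sketch of how the new `ulysses_anything` flag is meant to be used, following the docstring constraints above (the top-level `diffusers` import path for `ContextParallelConfig` is an assumption and may differ by version):

from diffusers import ContextParallelConfig

# Ulysses Anything: sequence lengths and head counts need not divide evenly
# by ulysses_degree. Requires ulysses_degree > 1, with ring_degree left at 1.
config = ContextParallelConfig(ulysses_degree=4, ulysses_anything=True)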
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
@@ -1567,7 +1567,7 @@ def enable_parallelism(
        mesh = None
        if config.context_parallel_config is not None:
            cp_config = config.context_parallel_config
-            mesh = torch.distributed.device_mesh.init_device_mesh(
+            mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
                device_type=device_type,
                mesh_shape=cp_config.mesh_shape,
                mesh_dim_names=cp_config.mesh_dim_names,
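End to end, the two changes let a caller build one mesh and share it: `enable_parallelism` now uses `cp_config.mesh` when set and only falls back to creating its own. A rough sketch mirroring the docstring's 8-GPU example (assumptions: `ContextParallelConfig` is importable from the top-level `diffusers` namespace, and the chosen model class ships a `_cp_plan`; swap in any such model):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from diffusers import ContextParallelConfig, FluxTransformer2DModel

# One process per GPU, e.g. launched via `torchrun --nproc_per_node=8`.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# 8 GPUs: ring attention across 2 devices, FSDP sharding across 4, and the
# unused "ulysses" dimension kept at size 1 (matching the docstring example).
mesh = init_device_mesh(
    "cuda", mesh_shape=(2, 1, 4), mesh_dim_names=("ring", "ulysses", "fsdp")
)

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to(f"cuda:{rank}")

# Pass the shared mesh instead of letting enable_parallelism create a new one.
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2, mesh=mesh))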
105 changes: 99 additions & 6 deletions tests/models/testing_utils/parallelism.py
@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
        model.eval()

        # Move inputs to device
-        inputs_on_device = {}
-        for key, value in inputs_dict.items():
-            if isinstance(value, torch.Tensor):
-                inputs_on_device[key] = value.to(device)
-            else:
-                inputs_on_device[key] = value
+        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

        # Enable context parallelism
        cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@
        dist.destroy_process_group()


def _custom_mesh_worker(
    rank,
    world_size,
    master_port,
    model_class,
    init_dict,
    cp_dict,
    mesh_shape,
    mesh_dim_names,
    inputs_dict,
    return_dict,
):
    """Worker function for context parallel testing with a user-provided custom DeviceMesh."""
    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        model = model_class(**init_dict)
        model.to(device)
        model.eval()

        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

        # DeviceMesh must be created after init_process_group, inside each worker process.
        mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
        model.enable_parallelism(config=cp_config)

        with torch.no_grad():
            output = model(**inputs_on_device, return_dict=False)[0]

        if rank == 0:
            return_dict["status"] = "success"
            return_dict["output_shape"] = list(output.shape)

    except Exception as e:
        if rank == 0:
            return_dict["status"] = "error"
            return_dict["error"] = str(e)
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()


@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ def test_context_parallel_inference(self, cp_type):
        assert return_dict.get("status") == "success", (
            f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
        )

    @pytest.mark.parametrize(
        "cp_type,mesh_shape,mesh_dim_names",
        [
            ("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
            ("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
        ],
        ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
    )
    def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
        if not torch.distributed.is_available():
            pytest.skip("torch.distributed is not available.")

        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

        world_size = 2
        init_dict = self.get_init_dict()
        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
        cp_dict = {cp_type: world_size}

        master_port = _find_free_port()
        manager = mp.Manager()
        return_dict = manager.dict()

        mp.spawn(
            _custom_mesh_worker,
            args=(
                world_size,
                master_port,
                self.model_class,
                init_dict,
                cp_dict,
                mesh_shape,
                mesh_dim_names,
                inputs_dict,
                return_dict,
            ),
            nprocs=world_size,
            join=True,
        )

        assert return_dict.get("status") == "success", (
            f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
        )
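Note that the test above only constructs the three-dimensional mesh to exercise the custom-mesh code path; nothing is actually sharded. When genuinely combining with FSDP, the leftover "fsdp" sub-mesh would be handed to the sharding API. A sketch with PyTorch FSDP2, assuming a recent PyTorch where `fully_shard` is exposed under `torch.distributed.fsdp` and accepts a 1-D sub-mesh:

from torch.distributed.fsdp import fully_shard

# Shard parameters across the "fsdp" dimension of the same DeviceMesh whose
# "ring"/"ulysses" dimensions drive context parallelism.
fully_shard(model, mesh=mesh["fsdp"])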