diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index ed966dc8fe98..8573c01ca4c7 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -60,6 +60,16 @@ class ContextParallelConfig: rotate_method (`str`, *optional*, defaults to `"allgather"`): Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"` is supported. + ulysses_anything (`bool`, *optional*, defaults to `False`): + Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that + are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and + `ring_degree` must be 1. + mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*): + A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of + creating a new one. This is useful when combining context parallelism with other parallelism strategies + (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and + "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with + `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP). """ @@ -68,6 +78,7 @@ class ContextParallelConfig: convert_to_fp32: bool = True # TODO: support alltoall rotate_method: Literal["allgather", "alltoall"] = "allgather" + mesh: torch.distributed.device_mesh.DeviceMesh | None = None # Whether to enable ulysses anything attention to support # any sequence lengths and any head numbers. ulysses_anything: bool = False @@ -124,7 +135,7 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})." ) - self._flattened_mesh = self._mesh._flatten() + self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten() self._ring_mesh = self._mesh["ring"] self._ulysses_mesh = self._mesh["ulysses"] self._ring_local_rank = self._ring_mesh.get_local_rank() diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0901840679e3..401074050333 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1567,7 +1567,7 @@ def enable_parallelism( mesh = None if config.context_parallel_config is not None: cp_config = config.context_parallel_config - mesh = torch.distributed.device_mesh.init_device_mesh( + mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh( device_type=device_type, mesh_shape=cp_config.mesh_shape, mesh_dim_names=cp_config.mesh_dim_names, diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index e05b36799e66..3858acf71ec5 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di model.eval() # Move inputs to device - inputs_on_device = {} - for key, value in inputs_dict.items(): - if isinstance(value, torch.Tensor): - inputs_on_device[key] = value.to(device) - else: - inputs_on_device[key] = value + inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} # Enable context parallelism cp_config = ContextParallelConfig(**cp_dict) @@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di dist.destroy_process_group() +def _custom_mesh_worker( + rank, + world_size, + master_port, + model_class, + init_dict, + cp_dict, + mesh_shape, + mesh_dim_names, + inputs_dict, + return_dict, +): + """Worker function for context parallel testing with a user-provided custom DeviceMesh.""" + try: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + + model = model_class(**init_dict) + model.to(device) + model.eval() + + inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + + # DeviceMesh must be created after init_process_group, inside each worker process. + mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names + ) + cp_config = ContextParallelConfig(**cp_dict, mesh=mesh) + model.enable_parallelism(config=cp_config) + + with torch.no_grad(): + output = model(**inputs_on_device, return_dict=False)[0] + + if rank == 0: + return_dict["status"] = "success" + return_dict["output_shape"] = list(output.shape) + + except Exception as e: + if rank == 0: + return_dict["status"] = "error" + return_dict["error"] = str(e) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + @is_context_parallel @require_torch_multi_accelerator class ContextParallelTesterMixin: @@ -126,3 +174,48 @@ def test_context_parallel_inference(self, cp_type): assert return_dict.get("status") == "success", ( f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + + @pytest.mark.parametrize( + "cp_type,mesh_shape,mesh_dim_names", + [ + ("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")), + ("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")), + ], + ids=["ring-3d-fsdp", "ulysses-3d-fsdp"], + ) + def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()} + cp_dict = {cp_type: world_size} + + master_port = _find_free_port() + manager = mp.Manager() + return_dict = manager.dict() + + mp.spawn( + _custom_mesh_worker, + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + mesh_shape, + mesh_dim_names, + inputs_dict, + return_dict, + ), + nprocs=world_size, + join=True, + ) + + assert return_dict.get("status") == "success", ( + f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}" + )