Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,9 +1865,12 @@ def forward(
out = out.to(torch.float32)
lse = lse.to(torch.float32)

# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# lse must be 4-D to broadcast with out (B, S, H, D).
# Some backends (e.g. cuDNN on torch>=2.9) already return a
# trailing-1 dim; others (e.g. flash-hub / native-flash) always
# return 3-D lse, so we add the dim here when needed.
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if lse.ndim == 3:
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
Expand Down Expand Up @@ -2154,10 +2157,11 @@ def _templated_unified_attention(
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its
# final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add
# the trailing dim here and remove it after the collective.
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if lse.ndim == 3:
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/testing_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
Expand Down Expand Up @@ -45,6 +45,7 @@
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"ContextParallelAttentionBackendsTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
Expand Down
93 changes: 88 additions & 5 deletions tests/models/testing_utils/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@

from diffusers.models._modeling_parallel import ContextParallelConfig

from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator
from .utils import _maybe_cast_to_bf16


def _find_free_port():
Expand All @@ -38,7 +36,9 @@ def _find_free_port():
return port


def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
def _context_parallel_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
):
"""Worker function for context parallel testing."""
try:
# Set up distributed environment
Expand All @@ -59,6 +59,9 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.to(device)
model.eval()

# Cast as needed.
model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)

# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
Expand All @@ -67,6 +70,13 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
else:
inputs_on_device[key] = value

# Enable attention backend
if attention_backend:
try:
model.set_attention_backend(attention_backend)
except Exception as e:
pytest.skip(f"Skipping test because of exception: {e}.")

# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
Expand Down Expand Up @@ -126,3 +136,76 @@ def test_context_parallel_inference(self, cp_type):
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)


@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelAttentionBackendsTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
@pytest.mark.parametrize(
"attention_backend",
[
"native",
pytest.param(
"flash_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
"_flash_3_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
],
)
@pytest.mark.parametrize("ulysses_anything", [True, False])
@torch.no_grad()
def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")

if getattr(self.model_class, "_cp_plan", None) is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

if cp_type == "ring_degree":
if attention_backend == "native":
pytest.skip("Skipping test because ulysses isn't supported with native attention backend.")

if ulysses_anything and "ulysses" not in cp_type:
pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")

world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()

# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
if ulysses_anything:
cp_dict.update({"ulysses_anything": ulysses_anything})

# Find a free port for distributed communication
master_port = _find_free_port()

# Use multiprocessing manager for cross-process communication
manager = mp.Manager()
return_dict = manager.dict()

# Spawn worker processes
mp.spawn(
_context_parallel_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
inputs_dict,
return_dict,
attention_backend,
),
nprocs=world_size,
join=True,
)

assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
22 changes: 22 additions & 0 deletions tests/models/testing_utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch

from diffusers.models.attention_dispatch import AttentionBackendName


_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName._FLASH_3_HUB,
}


def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if not backend or backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict
7 changes: 7 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelAttentionBackendsTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
Expand Down Expand Up @@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar
"""Context Parallel inference tests for Flux Transformer"""


class TestFluxTransformerContextParallelAttnBackends(
FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
):
"""Context Parallel inference x attention backends tests for Flux Transformer"""


class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""

Expand Down
11 changes: 10 additions & 1 deletion utils/generate_model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
]


Expand Down Expand Up @@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se

for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
if tester == "ContextParallelAttentionBackendsTesterMixin":
if (
"cp_attn" in include_optional
and "_cp_plan" in model_info["attributes"]
and model_info["attributes"]["_cp_plan"] is not None
):
testers.append(tester)
elif tester not in testers:
testers.append(tester)

return testers
Expand Down Expand Up @@ -530,6 +538,7 @@ def main():
"faster_cache",
"single_file",
"ip_adapter",
"cp_attn",
"all",
],
help="Optional testers to include",
Expand Down
Loading