From 5e87c38b29e23f9745d63bde8969ce557ac832ab Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 22 Feb 2026 12:25:59 +0530 Subject: [PATCH] remove non-hub attention backends. --- src/diffusers/models/attention_dispatch.py | 742 +-------------------- 1 file changed, 34 insertions(+), 708 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 90ffcac80dc5..ecf1d9dd496d 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -34,12 +34,7 @@ get_logger, is_aiter_available, is_aiter_version, - is_flash_attn_3_available, - is_flash_attn_available, - is_flash_attn_version, is_kernels_available, - is_sageattention_available, - is_sageattention_version, is_torch_npu_available, is_torch_version, is_torch_xla_available, @@ -55,62 +50,23 @@ if TYPE_CHECKING: from ._modeling_parallel import ParallelConfig -_REQUIRED_FLASH_VERSION = "2.6.3" _REQUIRED_AITER_VERSION = "0.1.5" -_REQUIRED_SAGE_VERSION = "2.1.1" _REQUIRED_FLEX_VERSION = "2.5.0" _REQUIRED_XLA_VERSION = "2.2" _REQUIRED_XFORMERS_VERSION = "0.0.29" -_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) -_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION) -_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) _CAN_USE_NPU_ATTN = is_torch_npu_available() _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) -if _CAN_USE_FLASH_ATTN: - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward -else: - flash_attn_func = None - flash_attn_varlen_func = None - _wrapped_flash_attn_backward = None - _wrapped_flash_attn_forward = None - - -if _CAN_USE_FLASH_ATTN_3: - from flash_attn_interface import flash_attn_func as flash_attn_3_func - from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func -else: - flash_attn_3_func = None - flash_attn_3_varlen_func = None - if _CAN_USE_AITER_ATTN: from aiter import flash_attn_func as aiter_flash_attn_func else: aiter_flash_attn_func = None -if _CAN_USE_SAGE_ATTN: - from sageattention import ( - sageattn, - sageattn_qk_int8_pv_fp8_cuda, - sageattn_qk_int8_pv_fp8_cuda_sm90, - sageattn_qk_int8_pv_fp16_cuda, - sageattn_qk_int8_pv_fp16_triton, - sageattn_varlen, - ) -else: - sageattn = None - sageattn_qk_int8_pv_fp16_cuda = None - sageattn_qk_int8_pv_fp16_triton = None - sageattn_qk_int8_pv_fp8_cuda = None - sageattn_qk_int8_pv_fp8_cuda_sm90 = None - sageattn_varlen = None - if _CAN_USE_FLEX_ATTN: # We cannot import the flex_attention function from the package directly because it is expected (from the @@ -136,27 +92,6 @@ else: xops = None -# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4 -if torch.__version__ >= "2.4.0": - _custom_op = torch.library.custom_op - _register_fake = torch.library.register_fake -else: - - def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None): - def wrap(func): - return func - - return wrap if fn is None else fn - - def register_fake_no_op(op, fn=None, /, *, lib=None, 
_stacklevel=1): - def wrap(func): - return func - - return wrap if fn is None else fn - - _custom_op = custom_op_no_op - _register_fake = register_fake_no_op - logger = get_logger(__name__) # pylint: disable=invalid-name @@ -304,11 +239,11 @@ def attention_backend(backend: str | AttentionBackendName = AttentionBackendName """ Context manager to set the active attention backend. """ - if backend not in _AttentionBackendRegistry._backends: - raise ValueError(f"Backend {backend} is not registered.") - backend = AttentionBackendName(backend) _check_attention_backend_requirements(backend) + + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") _maybe_download_kernel_for_backend(backend) old_backend = _AttentionBackendRegistry._active_backend @@ -442,16 +377,32 @@ def _check_shape( def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: - if not _CAN_USE_FLASH_ATTN: - raise RuntimeError( - f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." - ) + raise RuntimeError( + f"The '{backend.value}' attention backend has been removed. " + f"Please use 'flash_hub' or 'flash_varlen_hub' instead, which load the flash-attn kernel from the Hub. " + f"Install the required package with `pip install kernels`." + ) elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: - if not _CAN_USE_FLASH_ATTN_3: - raise RuntimeError( - f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." - ) + raise RuntimeError( + f"The '{backend.value}' attention backend has been removed. " + f"Please use '_flash_3_hub' or '_flash_3_varlen_hub' instead, which load the flash-attn-3 kernel from the Hub. " + f"Install the required package with `pip install kernels`." + ) + + elif backend in [ + AttentionBackendName.SAGE, + AttentionBackendName.SAGE_VARLEN, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + ]: + raise RuntimeError( + f"The '{backend.value}' attention backend has been removed. " + f"Please use 'sage_hub' instead, which loads the SageAttention kernel from the Hub. " + f"Install the required package with `pip install kernels`." + ) elif backend in [ AttentionBackendName.FLASH_HUB, @@ -471,19 +422,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`." ) - elif backend in [ - AttentionBackendName.SAGE, - AttentionBackendName.SAGE_VARLEN, - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, - AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, - AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, - ]: - if not _CAN_USE_SAGE_ATTN: - raise RuntimeError( - f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." 
- ) - elif backend == AttentionBackendName.FLEX: if not _CAN_USE_FLEX_ATTN: raise RuntimeError( @@ -652,78 +590,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: raise -# ===== torch op registrations ===== -# Registrations are required for fullgraph tracing compatibility -# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding -# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590 -@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _wrapped_flash_attn_3( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale: float | None = None, - causal: bool = False, - qv: torch.Tensor | None = None, - q_descale: torch.Tensor | None = None, - k_descale: torch.Tensor | None = None, - v_descale: torch.Tensor | None = None, - attention_chunk: int = 0, - softcap: float = 0.0, - num_splits: int = 1, - pack_gqa: bool | None = None, - deterministic: bool = False, - sm_margin: int = 0, -) -> tuple[torch.Tensor, torch.Tensor]: - # Hardcoded for now because pytorch does not support tuple/int type hints - window_size = (-1, -1) - out, lse, *_ = flash_attn_3_func( - q=q, - k=k, - v=v, - softmax_scale=softmax_scale, - causal=causal, - qv=qv, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - window_size=window_size, - attention_chunk=attention_chunk, - softcap=softcap, - num_splits=num_splits, - pack_gqa=pack_gqa, - deterministic=deterministic, - sm_margin=sm_margin, - ) - lse = lse.permute(0, 2, 1) - return out, lse - - -@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward") -def _( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - softmax_scale: float | None = None, - causal: bool = False, - qv: torch.Tensor | None = None, - q_descale: torch.Tensor | None = None, - k_descale: torch.Tensor | None = None, - v_descale: torch.Tensor | None = None, - attention_chunk: int = 0, - softcap: float = 0.0, - num_splits: int = 1, - pack_gqa: bool | None = None, - deterministic: bool = False, - sm_margin: int = 0, -) -> tuple[torch.Tensor, torch.Tensor]: - window_size = (-1, -1) # noqa: F841 - # A lot of the parameters here are not yet used in any way within diffusers. - # We can safely ignore for now and keep the fake op shape propagation simple. 
- batch_size, seq_len, num_heads, head_dim = q.shape - lse_shape = (batch_size, seq_len, num_heads) - return torch.empty_like(q), q.new_empty(lse_shape) - - # ===== Helper functions to use attention backends with templated CP autograd functions ===== @@ -995,107 +861,6 @@ def _native_flash_attention_backward_op( return grad_query, grad_key, grad_value -# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 -def _flash_attention_forward_op( - ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float | None = None, - enable_gqa: bool = False, - return_lse: bool = False, - _save_ctx: bool = True, - _parallel_config: "ParallelConfig" | None = None, -): - if attn_mask is not None: - raise ValueError("`attn_mask` is not yet supported for flash-attn 2.") - if enable_gqa: - raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.") - - # Hardcoded for now - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - grad_enabled = any(x.requires_grad for x in (query, key, value)) - - if scale is None: - scale = query.shape[-1] ** (-0.5) - - # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. - if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): - dropout_p = dropout_p if dropout_p > 0 else 1e-30 - - with torch.set_grad_enabled(grad_enabled): - out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward( - query, - key, - value, - dropout_p, - scale, - is_causal, - window_size[0], - window_size[1], - softcap, - alibi_slopes, - return_lse, - ) - lse = lse.permute(0, 2, 1) - - if _save_ctx: - ctx.save_for_backward(query, key, value, out, lse, rng_state) - ctx.dropout_p = dropout_p - ctx.scale = scale - ctx.is_causal = is_causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - - return (out, lse) if return_lse else out - - -def _flash_attention_backward_op( - ctx: torch.autograd.function.FunctionCtx, - grad_out: torch.Tensor, - *args, - **kwargs, -): - query, key, value, out, lse, rng_state = ctx.saved_tensors - grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) - - lse_d = _wrapped_flash_attn_backward( # noqa: F841 - grad_out, - query, - key, - value, - out, - lse, - grad_query, - grad_key, - grad_value, - ctx.dropout_p, - ctx.scale, - ctx.is_causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.alibi_slopes, - ctx.deterministic, - rng_state, - ) - - # Head dimension may have been padded - grad_query = grad_query[..., : grad_out.shape[-1]] - grad_key = grad_key[..., : grad_out.shape[-1]] - grad_value = grad_value[..., : grad_out.shape[-1]] - - return grad_query, grad_key, grad_value - - def _flash_attention_hub_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, @@ -1327,44 +1092,6 @@ def _flash_attention_3_hub_backward_op( return grad_query, grad_key, grad_value -def _sage_attention_forward_op( - ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float | None = None, - enable_gqa: bool = False, - 
return_lse: bool = False, - _save_ctx: bool = True, - _parallel_config: "ParallelConfig" | None = None, -): - if attn_mask is not None: - raise ValueError("`attn_mask` is not yet supported for Sage attention.") - if dropout_p > 0.0: - raise ValueError("`dropout_p` is not yet supported for Sage attention.") - if enable_gqa: - raise ValueError("`enable_gqa` is not yet supported for Sage attention.") - - out = sageattn( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - lse = None - if return_lse: - out, lse, *_ = out - lse = lse.permute(0, 2, 1) - - return (out, lse) if return_lse else out - - def _sage_attention_hub_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, @@ -2205,59 +1932,6 @@ def _templated_context_parallel_attention( # ===== Attention backends ===== -@_AttentionBackendRegistry.register( - AttentionBackendName.FLASH, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=True, -) -def _flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - lse = None - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for flash-attn 2.") - - if _parallel_config is None: - out = flash_attn_func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - if return_lse: - out, lse, *_ = out - else: - out = _templated_context_parallel_attention( - query, - key, - value, - None, - dropout_p, - is_causal, - scale, - False, - return_lse, - forward_op=_flash_attention_forward_op, - backward_op=_flash_attention_backward_op, - _parallel_config=_parallel_config, - ) - if return_lse: - out, lse = out - - return (out, lse) if return_lse else out - - @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], @@ -2370,103 +2044,21 @@ def _flash_varlen_attention_hub( @_AttentionBackendRegistry.register( - AttentionBackendName.FLASH_VARLEN, + AttentionBackendName._FLASH_3_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) -def _flash_varlen_attention( +def _flash_attention_3_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, - dropout_p: float = 0.0, scale: float | None = None, is_causal: bool = False, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) - - query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) - - out = flash_attn_varlen_func( - 
q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) - out = out.unflatten(0, (batch_size, -1)) - - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._FLASH_3, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _flash_attention_3( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - scale: float | None = None, - is_causal: bool = False, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for flash-attn 3.") - - out, lse = _wrapped_flash_attn_3( - q=query, - k=key, - v=value, - softmax_scale=scale, - causal=is_causal, - ) - return (out, lse) if return_lse else out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._FLASH_3_HUB, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=True, -) -def _flash_attention_3_hub( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - scale: float | None = None, - is_causal: bool = False, - window_size: tuple[int, int] = (-1, -1), - softcap: float = 0.0, - deterministic: bool = False, - return_attn_probs: bool = False, + window_size: tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, _parallel_config: "ParallelConfig" | None = None, ) -> torch.Tensor: if attn_mask is not None: @@ -2587,58 +2179,6 @@ def _flash_attention_3_varlen_hub( return (out, lse) if return_lse else out -@_AttentionBackendRegistry.register( - AttentionBackendName._FLASH_VARLEN_3, - constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _flash_varlen_attention_3( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - scale: float | None = None, - is_causal: bool = False, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) - - query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) - - out, lse, *_ = flash_attn_3_varlen_func( - q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=scale, - causal=is_causal, - ) - out = out.unflatten(0, (batch_size, -1)) - - return (out, lse) if return_lse else out - - @_AttentionBackendRegistry.register( AttentionBackendName.AITER, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], @@ -3108,57 +2648,6 @@ def 
_native_xla_attention( return out -@_AttentionBackendRegistry.register( - AttentionBackendName.SAGE, - constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=True, -) -def _sage_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for sage attention") - lse = None - if _parallel_config is None: - out = sageattn( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - if return_lse: - out, lse, *_ = out - else: - out = _templated_context_parallel_attention( - query, - key, - value, - None, - 0.0, - is_causal, - scale, - False, - return_lse, - forward_op=_sage_attention_forward_op, - backward_op=_sage_attention_backward_op, - _parallel_config=_parallel_config, - ) - if return_lse: - out, lse = out - - return (out, lse) if return_lse else out - - @_AttentionBackendRegistry.register( AttentionBackendName.SAGE_HUB, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], @@ -3211,169 +2700,6 @@ def _sage_attention_hub( return (out, lse) if return_lse else out -@_AttentionBackendRegistry.register( - AttentionBackendName.SAGE_VARLEN, - constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], -) -def _sage_varlen_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if return_lse: - raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") - - batch_size, seq_len_q, _, _ = query.shape - _, seq_len_kv, _, _ = key.shape - - if attn_mask is not None: - attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) - - (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( - _prepare_for_flash_attn_or_sage_varlen( - batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device - ) - ) - - key_valid, value_valid = [], [] - for b in range(batch_size): - valid_len = seqlens_k[b] - key_valid.append(key[b, :valid_len]) - value_valid.append(value[b, :valid_len]) - - query_packed = query.flatten(0, 1) - key_packed = torch.cat(key_valid, dim=0) - value_packed = torch.cat(value_valid, dim=0) - - out = sageattn_varlen( - q=query_packed, - k=key_packed, - v=value_packed, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - is_causal=is_causal, - sm_scale=scale, - ) - out = out.unflatten(0, (batch_size, -1)) - - return out - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, - constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], -) -def _sage_qk_int8_pv_fp8_cuda_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for sage attention") - return 
sageattn_qk_int8_pv_fp8_cuda( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, - constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], -) -def _sage_qk_int8_pv_fp8_cuda_sm90_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for sage attention") - return sageattn_qk_int8_pv_fp8_cuda_sm90( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, - constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], -) -def _sage_qk_int8_pv_fp16_cuda_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for sage attention") - return sageattn_qk_int8_pv_fp16_cuda( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - - -@_AttentionBackendRegistry.register( - AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, - constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], -) -def _sage_qk_int8_pv_fp16_triton_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: torch.Tensor | None = None, - is_causal: bool = False, - scale: float | None = None, - return_lse: bool = False, - _parallel_config: "ParallelConfig" | None = None, -) -> torch.Tensor: - if attn_mask is not None: - raise ValueError("`attn_mask` is not supported for sage attention") - return sageattn_qk_int8_pv_fp16_triton( - q=query, - k=key, - v=value, - tensor_layout="NHD", - is_causal=is_causal, - sm_scale=scale, - return_lse=return_lse, - ) - - @_AttentionBackendRegistry.register( AttentionBackendName.XFORMERS, constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
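
Note (not part of the diff): a minimal migration sketch for callers of the removed backends. It assumes a diffusers build that contains this patch, `pip install kernels` so the Hub-hosted kernels can be fetched, and an illustrative Flux checkpoint; `set_attention_backend` and the `attention_backend` context manager are the two entry points used here.

# Migration sketch, not part of the patch. After this change, selecting a removed
# backend such as "flash", "_flash_3", or "sage" raises a RuntimeError that points
# at the Hub-backed variant ("flash_hub", "_flash_3_hub", "sage_hub").
# Assumes `pip install kernels`; the checkpoint below is only illustrative.
import torch

from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import attention_backend

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Option 1: set a Hub-backed kernel on the model (previously: "flash").
pipe.transformer.set_attention_backend("flash_hub")

# Option 2: scope a backend to a single call with the context manager
# (previously: "_flash_3"; the FA3 Hub kernel still requires a supported GPU).
with attention_backend("_flash_3_hub"):
    image = pipe("a photo of a cat", num_inference_steps=28).images[0]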