|
23 | 23 | import comfy.float |
24 | 24 | import comfy.rmsnorm |
25 | 25 | import contextlib |
26 | | -from torch.nn.attention import SDPBackend, sdpa_kernel |
27 | 26 |
|
28 | | -cast_to = comfy.model_management.cast_to #TODO: remove once no more references |
29 | 27 |
|
30 | | -SDPA_BACKEND_PRIORITY = [ |
31 | | - SDPBackend.FLASH_ATTENTION, |
32 | | - SDPBackend.EFFICIENT_ATTENTION, |
33 | | - SDPBackend.MATH, |
34 | | -] |
35 | | -if torch.cuda.is_available(): |
36 | | - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) |
| 28 | +def scaled_dot_product_attention(q, k, v, *args, **kwargs): |
| 29 | + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) |
| 30 | + |
| 31 | + |
| 32 | +try: |
| 33 | + if torch.cuda.is_available(): |
| 34 | + from torch.nn.attention import SDPBackend, sdpa_kernel |
| 35 | + |
| 36 | + SDPA_BACKEND_PRIORITY = [ |
| 37 | + SDPBackend.FLASH_ATTENTION, |
| 38 | + SDPBackend.EFFICIENT_ATTENTION, |
| 39 | + SDPBackend.MATH, |
| 40 | + ] |
| 41 | + |
| 42 | + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) |
| 43 | + |
| 44 | + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) |
| 45 | + def scaled_dot_product_attention(q, k, v, *args, **kwargs): |
| 46 | + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) |
| 47 | +except (ModuleNotFoundError, TypeError): |
| 48 | + logging.warning("Could not set sdpa backend priority.") |
| 49 | + |
| 50 | +cast_to = comfy.model_management.cast_to #TODO: remove once no more references |
37 | 51 |
|
38 | 52 | def cast_to_input(weight, input, non_blocking=False, copy=True): |
39 | 53 | return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) |
@@ -258,10 +272,6 @@ def conv_nd(s, dims, *args, **kwargs): |
258 | 272 | else: |
259 | 273 | raise ValueError(f"unsupported dimensions: {dims}") |
260 | 274 |
|
261 | | - @staticmethod |
262 | | - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) |
263 | | - def scaled_dot_product_attention(q, k, v, *args, **kwargs): |
264 | | - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) |
265 | 275 |
|
266 | 276 | class manual_cast(disable_weight_init): |
267 | 277 | class Linear(disable_weight_init.Linear): |
|
0 commit comments