|
| 1 | +import logging |
| 2 | + |
1 | 3 | import torch |
2 | 4 |
|
| 5 | +_CK_STOCHASTIC_ROUNDING_AVAILABLE = False |
| 6 | +try: |
| 7 | + import comfy_kitchen as ck |
| 8 | + _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8 |
| 9 | + _CK_STOCHASTIC_ROUNDING_AVAILABLE = True |
| 10 | +except (AttributeError, ImportError): |
| 11 | + logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.") |
| 12 | + |
| 13 | +if not _CK_STOCHASTIC_ROUNDING_AVAILABLE: |
| 14 | + def _ck_stochastic_rounding_fp8(value, rng, dtype): |
| 15 | + raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding") |
| 16 | + |
| 17 | + |
3 | 18 | def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): |
4 | 19 | mantissa_scaled = torch.where( |
5 | 20 | normal_mask, |
@@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0): |
57 | 72 | if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: |
58 | 73 | generator = torch.Generator(device=value.device) |
59 | 74 | generator.manual_seed(seed) |
| 75 | + if _CK_STOCHASTIC_ROUNDING_AVAILABLE: |
| 76 | + rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator) |
| 77 | + return _ck_stochastic_rounding_fp8(value, rng, dtype) |
| 78 | + |
60 | 79 | output = torch.empty_like(value, dtype=dtype) |
61 | 80 | num_slices = max(1, (value.numel() / (4096 * 4096))) |
62 | 81 | slice_size = max(1, round(value.shape[0] / num_slices)) |
|
0 commit comments