Skip to content

Commit 6842961

Browse files
authored
float: use CK stochastic rounding cuda kernel (Comfy-Org#13971)
1 parent ade4dfd commit 6842961

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

comfy/float.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
1+
import logging
2+
13
import torch
24

5+
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
6+
try:
7+
import comfy_kitchen as ck
8+
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
9+
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
10+
except (AttributeError, ImportError):
11+
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
12+
13+
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
14+
def _ck_stochastic_rounding_fp8(value, rng, dtype):
15+
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
16+
17+
318
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
419
mantissa_scaled = torch.where(
520
normal_mask,
@@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
5772
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
5873
generator = torch.Generator(device=value.device)
5974
generator.manual_seed(seed)
75+
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
76+
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
77+
return _ck_stochastic_rounding_fp8(value, rng, dtype)
78+
6079
output = torch.empty_like(value, dtype=dtype)
6180
num_slices = max(1, (value.numel() / (4096 * 4096)))
6281
slice_size = max(1, round(value.shape[0] / num_slices))

0 commit comments

Comments
 (0)