Skip to content

Commit c58c13b

Browse files
Fix torch compile regression on fp8 ops. (Comfy-Org#10580)
1 parent 7f374e4 commit c58c13b

4 files changed

Lines changed: 43 additions & 36 deletions

File tree

comfy/ops.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -401,15 +401,9 @@ def fp8_linear(self, input):
401401
if dtype not in [torch.float8_e4m3fn]:
402402
return None
403403

404-
tensor_2d = False
405-
if len(input.shape) == 2:
406-
tensor_2d = True
407-
input = input.unsqueeze(1)
408-
409-
input_shape = input.shape
410404
input_dtype = input.dtype
411405

412-
if len(input.shape) == 3:
406+
if input.ndim == 3 or input.ndim == 2:
413407
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
414408

415409
scale_weight = self.scale_weight
@@ -422,24 +416,20 @@ def fp8_linear(self, input):
422416
if scale_input is None:
423417
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
424418
input = torch.clamp(input, min=-448, max=448, out=input)
425-
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
426419
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
427-
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
420+
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
428421
else:
429422
scale_input = scale_input.to(input.device)
430-
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
423+
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
431424

432425
# Wrap weight in QuantizedTensor - this enables unified dispatch
433426
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
434427
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
435-
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
428+
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
436429
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
437430

438431
uncast_bias_weight(self, w, bias, offload_stream)
439-
440-
if tensor_2d:
441-
return o.reshape(input_shape[0], -1)
442-
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
432+
return o
443433

444434
return None
445435

@@ -540,12 +530,12 @@ def forward(self, *args, **kwargs):
540530
# ==============================================================================
541531
# Mixed Precision Operations
542532
# ==============================================================================
543-
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
533+
from .quant_ops import QuantizedTensor
544534

545535
QUANT_FORMAT_MIXINS = {
546536
"float8_e4m3fn": {
547537
"dtype": torch.float8_e4m3fn,
548-
"layout_type": TensorCoreFP8Layout,
538+
"layout_type": "TensorCoreFP8Layout",
549539
"parameters": {
550540
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
551541
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),

comfy/quant_ops.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def __new__(cls, qdata, layout_type, layout_params):
123123
layout_type: Layout class (subclass of QuantizedLayout)
124124
layout_params: Dict with layout-specific parameters
125125
"""
126-
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
126+
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
127127

128128
def __init__(self, qdata, layout_type, layout_params):
129129
self._qdata = qdata.contiguous()
@@ -183,11 +183,11 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
183183

184184
@classmethod
185185
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
186-
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
186+
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
187187
return cls(qdata, layout_type, layout_params)
188188

189189
def dequantize(self) -> torch.Tensor:
190-
return self._layout_type.dequantize(self._qdata, **self._layout_params)
190+
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
191191

192192
@classmethod
193193
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -379,7 +379,12 @@ def get_plain_tensors(cls, qtensor):
379379
return qtensor._qdata, qtensor._layout_params['scale']
380380

381381

382-
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
382+
LAYOUTS = {
383+
"TensorCoreFP8Layout": TensorCoreFP8Layout,
384+
}
385+
386+
387+
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
383388
def fp8_linear(func, args, kwargs):
384389
input_tensor = args[0]
385390
weight = args[1]
@@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
422427
'scale': output_scale,
423428
'orig_dtype': input_tensor._layout_params['orig_dtype']
424429
}
425-
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
430+
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
426431
else:
427432
return output
428433

@@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
436441
input_tensor = input_tensor.dequantize()
437442

438443
return torch.nn.functional.linear(input_tensor, weight, bias)
444+
445+
446+
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
447+
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
448+
def fp8_func(func, args, kwargs):
449+
input_tensor = args[0]
450+
if isinstance(input_tensor, QuantizedTensor):
451+
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
452+
ar = list(args)
453+
ar[0] = plain_input
454+
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
455+
return func(*args, **kwargs)

tests-unit/comfy_quant/test_mixed_precision.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def has_gpu():
1414
args.cpu = True
1515

1616
from comfy import ops
17-
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
17+
from comfy.quant_ops import QuantizedTensor
1818

1919

2020
class SimpleModel(torch.nn.Module):
@@ -104,14 +104,14 @@ def test_mixed_precision_load(self):
104104

105105
# Verify weights are wrapped in QuantizedTensor
106106
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
107-
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
107+
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
108108

109109
# Layer 2 should NOT be quantized
110110
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
111111

112112
# Layer 3 should be quantized
113113
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
114-
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
114+
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
115115

116116
# Verify scales were loaded
117117
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@@ -155,7 +155,7 @@ def test_state_dict_quantized_preserved(self):
155155
# Verify layer1.weight is a QuantizedTensor with scale preserved
156156
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
157157
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
158-
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
158+
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
159159

160160
# Verify non-quantized layers are standard tensors
161161
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)

tests-unit/comfy_quant/test_quant_registry.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def test_creation(self):
2525
scale = torch.tensor(2.0)
2626
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
2727

28-
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
28+
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
2929

3030
self.assertIsInstance(qt, QuantizedTensor)
3131
self.assertEqual(qt.shape, (256, 128))
3232
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
3333
self.assertEqual(qt._layout_params['scale'], scale)
3434
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
35-
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
35+
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
3636

3737
def test_dequantize(self):
3838
"""Test explicit dequantization"""
@@ -41,7 +41,7 @@ def test_dequantize(self):
4141
scale = torch.tensor(3.0)
4242
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
4343

44-
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
44+
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
4545
dequantized = qt.dequantize()
4646

4747
self.assertEqual(dequantized.dtype, torch.float32)
@@ -54,7 +54,7 @@ def test_from_float(self):
5454

5555
qt = QuantizedTensor.from_float(
5656
float_tensor,
57-
TensorCoreFP8Layout,
57+
"TensorCoreFP8Layout",
5858
scale=scale,
5959
dtype=torch.float8_e4m3fn
6060
)
@@ -77,28 +77,28 @@ def test_detach(self):
7777
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
7878
scale = torch.tensor(1.5)
7979
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
80-
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
80+
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
8181

8282
# Detach should return a new QuantizedTensor
8383
qt_detached = qt.detach()
8484

8585
self.assertIsInstance(qt_detached, QuantizedTensor)
8686
self.assertEqual(qt_detached.shape, qt.shape)
87-
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
87+
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
8888

8989
def test_clone(self):
9090
"""Test clone operation on quantized tensor"""
9191
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
9292
scale = torch.tensor(1.5)
9393
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
94-
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
94+
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
9595

9696
# Clone should return a new QuantizedTensor
9797
qt_cloned = qt.clone()
9898

9999
self.assertIsInstance(qt_cloned, QuantizedTensor)
100100
self.assertEqual(qt_cloned.shape, qt.shape)
101-
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
101+
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
102102

103103
# Verify it's a deep copy
104104
self.assertIsNot(qt_cloned._qdata, qt._qdata)
@@ -109,7 +109,7 @@ def test_to_device(self):
109109
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
110110
scale = torch.tensor(1.5)
111111
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
112-
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
112+
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
113113

114114
# Moving to same device should work (CPU to CPU)
115115
qt_cpu = qt.to('cpu')
@@ -169,7 +169,7 @@ def test_unsupported_op_dequantizes(self):
169169
scale = torch.tensor(1.0)
170170
a_q = QuantizedTensor.from_float(
171171
a_fp32,
172-
TensorCoreFP8Layout,
172+
"TensorCoreFP8Layout",
173173
scale=scale,
174174
dtype=torch.float8_e4m3fn
175175
)

0 commit comments

Comments
 (0)