Skip to content

Commit b47f15f

Browse files
authored
fix: Handle un-inited meta-tensors in models (fixes a CPU TE crash) (CORE-67) (Comfy-Org#13578)
1 parent 3cbf015 commit b47f15f

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

comfy/model_patcher.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import comfy.hooks
3232
import comfy.lora
3333
import comfy.model_management
34+
import comfy.ops
3435
import comfy.patcher_extension
3536
import comfy.utils
3637
from comfy.comfy_types import UnetWrapperFunction
@@ -856,7 +857,9 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
856857
if m.comfy_patched_weights == True:
857858
continue
858859

859-
for param in params:
860+
for param, param_value in params.items():
861+
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
862+
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
860863
key = key_param_name_to_key(n, param)
861864
self.unpin_weight(key)
862865
self.patch_weight_to_device(key, device_to=device_to)

comfy/ops.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7979
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
8080

8181

82-
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
82+
def materialize_meta_param(s, param_keys):
    """Swap meta-device attributes of ``s`` for real, zero-filled Parameters.

    For every name in ``param_keys`` the attribute is looked up on ``s``.
    Attributes that are missing, ``None``, or not flagged ``is_meta`` are
    left untouched. Each meta attribute is replaced by a
    ``torch.nn.Parameter`` of zeros with the same shape, dtype and
    ``requires_grad`` flag. No device is passed to ``torch.zeros``, so the
    new storage is allocated on torch's default device.

    Args:
        s: Object (typically a module) whose attributes are inspected and
            possibly replaced in place.
        param_keys: Iterable of attribute names to check, e.g.
            ``["weight", "bias"]``.
    """
    for name in param_keys:
        current = getattr(s, name, None)
        # Nothing to do for absent/None attributes or already-real tensors.
        if current is None or not getattr(current, "is_meta", False):
            continue
        zeros = torch.zeros(current.shape, dtype=current.dtype)
        replacement = torch.nn.Parameter(zeros, requires_grad=current.requires_grad)
        setattr(s, name, replacement)
87+
8388

89+
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
8490
#vbar doesn't support CPU weights, but some custom nodes have weird paths
8591
#that might switch the layer to the CPU and expect it to work. We have to take
8692
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
8793
#If you are a custom node author reading this, please move your layer to the GPU
8894
#or declare your ModelPatcher as CPU in the first place.
8995
if comfy.model_management.is_device_cpu(device):
96+
materialize_meta_param(s, ["weight", "bias"])
9097
weight = s.weight.to(dtype=dtype, copy=True)
9198
if isinstance(weight, QuantizedTensor):
9299
weight = weight.dequantize()
@@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
108115
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
109116

110117
if not resident:
118+
materialize_meta_param(s, ["weight", "bias"])
111119
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
112120
cast_dest = None
113121

@@ -306,6 +314,12 @@ class CastWeightBiasOp:
306314
bias_function = []
307315

308316
class disable_weight_init:
317+
@staticmethod
318+
def _zero_init_parameter(module, name):
319+
param = getattr(module, name)
320+
device = None if getattr(param, "is_meta", False) else param.device
321+
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
322+
309323
@staticmethod
310324
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
311325
missing_keys, unexpected_keys, weight_shape,

0 commit comments

Comments
 (0)