@@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
7979 return comfy .model_management .cast_to (weight , input .dtype , input .device , non_blocking = non_blocking , copy = copy )
8080
8181
82- def cast_bias_weight_with_vbar (s , dtype , device , bias_dtype , non_blocking , compute_dtype , want_requant ):
82+ def materialize_meta_param (s , param_keys ):
83+ for param_key in param_keys :
84+ param = getattr (s , param_key , None )
85+ if param is not None and getattr (param , "is_meta" , False ):
86+ setattr (s , param_key , torch .nn .Parameter (torch .zeros (param .shape , dtype = param .dtype ), requires_grad = param .requires_grad ))
87+
8388
89+ def cast_bias_weight_with_vbar (s , dtype , device , bias_dtype , non_blocking , compute_dtype , want_requant ):
8490 #vbar doesn't support CPU weights, but some custom nodes have weird paths
8591 #that might switch the layer to the CPU and expect it to work. We have to take
8692 #a clone conservatively as we are mmapped and some SFT files are packed misaligned
8793 #If you are a custom node author reading this, please move your layer to the GPU
8894 #or declare your ModelPatcher as CPU in the first place.
8995 if comfy .model_management .is_device_cpu (device ):
96+ materialize_meta_param (s , ["weight" , "bias" ])
9097 weight = s .weight .to (dtype = dtype , copy = True )
9198 if isinstance (weight , QuantizedTensor ):
9299 weight = weight .dequantize ()
@@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
108115 xfer_dest = comfy_aimdo .torch .aimdo_to_tensor (s ._v , device )
109116
110117 if not resident :
118+ materialize_meta_param (s , ["weight" , "bias" ])
111119 cast_geometry = comfy .memory_management .tensors_to_geometries ([ s .weight , s .bias ])
112120 cast_dest = None
113121
@@ -306,6 +314,12 @@ class CastWeightBiasOp:
306314 bias_function = []
307315
308316class disable_weight_init :
317+ @staticmethod
318+ def _zero_init_parameter (module , name ):
319+ param = getattr (module , name )
320+ device = None if getattr (param , "is_meta" , False ) else param .device
321+ setattr (module , name , torch .nn .Parameter (torch .zeros (param .shape , device = device , dtype = param .dtype ), requires_grad = False ))
322+
309323 @staticmethod
310324 def _lazy_load_from_state_dict (module , state_dict , prefix , local_metadata ,
311325 missing_keys , unexpected_keys , weight_shape ,
0 commit comments