@@ -685,27 +685,28 @@ def model_state_dict(self, filter_prefix=None):
685685 sd .pop (k )
686686 return sd
687687
688- def patch_weight_to_device (self , key , device_to = None , inplace_update = False , return_weight = False ):
688+ def patch_weight_to_device (self , key , device_to = None , inplace_update = False , return_weight = False , force_cast = False ):
689689 weight , set_func , convert_func = get_key_weight (self .model , key )
690- if key not in self .patches :
690+ if key not in self .patches and not force_cast :
691691 return weight
692692
693693 inplace_update = self .weight_inplace_update or inplace_update
694694
695695 if key not in self .backup and not return_weight :
696696 self .backup [key ] = collections .namedtuple ('Dimension' , ['weight' , 'inplace_update' ])(weight .to (device = self .offload_device , copy = inplace_update ), inplace_update )
697697
698- temp_dtype = comfy .model_management .lora_compute_dtype (device_to )
698+ temp_dtype = comfy .model_management .lora_compute_dtype (device_to ) if key in self . patches else None
699699 if device_to is not None :
700700 temp_weight = comfy .model_management .cast_to_device (weight , device_to , temp_dtype , copy = True )
701701 else :
702702 temp_weight = weight .to (temp_dtype , copy = True )
703703 if convert_func is not None :
704704 temp_weight = convert_func (temp_weight , inplace = True )
705705
706- out_weight = comfy .lora .calculate_weight (self .patches [key ], temp_weight , key )
706+ out_weight = comfy .lora .calculate_weight (self .patches [key ], temp_weight , key ) if key in self . patches else temp_weight
707707 if set_func is None :
708- out_weight = comfy .float .stochastic_rounding (out_weight , weight .dtype , seed = comfy .utils .string_to_seed (key ))
708+ if key in self .patches :
709+ out_weight = comfy .float .stochastic_rounding (out_weight , weight .dtype , seed = comfy .utils .string_to_seed (key ))
709710 if return_weight :
710711 return out_weight
711712 elif inplace_update :
@@ -1584,7 +1585,7 @@ def force_load_param(self, param_key, device_to):
15841585 key = key_param_name_to_key (n , param_key )
15851586 if key in self .backup :
15861587 comfy .utils .set_attr_param (self .model , key , self .backup [key ].weight )
1587- self .patch_weight_to_device (key , device_to = device_to )
1588+ self .patch_weight_to_device (key , device_to = device_to , force_cast = True )
15881589 weight , _ , _ = get_key_weight (self .model , key )
15891590 if weight is not None :
15901591 self .model .model_loaded_weight_memory += weight .numel () * weight .element_size ()
@@ -1609,6 +1610,10 @@ def force_load_param(self, param_key, device_to):
16091610 m ._v = vbar .alloc (v_weight_size )
16101611 allocated_size += v_weight_size
16111612
1613+ for param in params :
1614+ if param not in ("weight" , "bias" ):
1615+ force_load_param (self , param , device_to )
1616+
16121617 else :
16131618 for param in params :
16141619 key = key_param_name_to_key (n , param )
0 commit comments