Skip to content

Commit ec4b165

Browse files
authored
ModelPatcherDynamic: force cast stray weights on comfy layers (Comfy-Org#13487)
the mixed_precision ops can have input_scale parameters that are used in tensor math but arent a weight or bias so dont get proper VRAM management. Treat these as force-castable parameters like the non comfy weight, random params are buffers already are.
1 parent cb388e2 commit ec4b165

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

comfy/model_patcher.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -685,27 +685,28 @@ def model_state_dict(self, filter_prefix=None):
685685
sd.pop(k)
686686
return sd
687687

688-
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
688+
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
689689
weight, set_func, convert_func = get_key_weight(self.model, key)
690-
if key not in self.patches:
690+
if key not in self.patches and not force_cast:
691691
return weight
692692

693693
inplace_update = self.weight_inplace_update or inplace_update
694694

695695
if key not in self.backup and not return_weight:
696696
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
697697

698-
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
698+
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
699699
if device_to is not None:
700700
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
701701
else:
702702
temp_weight = weight.to(temp_dtype, copy=True)
703703
if convert_func is not None:
704704
temp_weight = convert_func(temp_weight, inplace=True)
705705

706-
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
706+
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
707707
if set_func is None:
708-
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
708+
if key in self.patches:
709+
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
709710
if return_weight:
710711
return out_weight
711712
elif inplace_update:
@@ -1584,7 +1585,7 @@ def force_load_param(self, param_key, device_to):
15841585
key = key_param_name_to_key(n, param_key)
15851586
if key in self.backup:
15861587
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
1587-
self.patch_weight_to_device(key, device_to=device_to)
1588+
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
15881589
weight, _, _ = get_key_weight(self.model, key)
15891590
if weight is not None:
15901591
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
@@ -1609,6 +1610,10 @@ def force_load_param(self, param_key, device_to):
16091610
m._v = vbar.alloc(v_weight_size)
16101611
allocated_size += v_weight_size
16111612

1613+
for param in params:
1614+
if param not in ("weight", "bias"):
1615+
force_load_param(self, param, device_to)
1616+
16121617
else:
16131618
for param in params:
16141619
key = key_param_name_to_key(n, param)

0 commit comments

Comments
 (0)