Skip to content

Commit 16f862f

Browse files
authored
implement dynamic clip saving (Comfy-Org#13959)
Fix clip saving by doing the same patching process and diffusion models.
1 parent d4c6c9e commit 16f862f

3 files changed

Lines changed: 28 additions & 14 deletions

File tree

comfy/model_patcher.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,38 +1493,45 @@ def clean_hooks(self):
14931493
self.unpatch_hooks()
14941494
self.clear_cached_hook_weights()
14951495

1496-
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
1497-
original_state_dict = self.model.diffusion_model.state_dict()
1498-
unet_state_dict = {}
1496+
def model_state_dict_for_saving(self, model=None, prefix=""):
1497+
if model is None:
1498+
model = self.model
1499+
1500+
original_state_dict = model.state_dict()
1501+
output_state_dict = {}
14991502
keys = list(original_state_dict)
15001503
while len(keys) > 0:
15011504
k = keys.pop(0)
15021505
v = original_state_dict[k]
15031506
op_keys = k.rsplit('.', 1)
15041507
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
1505-
unet_state_dict[k] = v
1508+
output_state_dict[k] = v
15061509
continue
15071510
try:
1508-
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
1511+
op = comfy.utils.get_attr(model, op_keys[0])
15091512
except:
1510-
unet_state_dict[k] = v
1513+
output_state_dict[k] = v
15111514
continue
15121515
if not op or not hasattr(op, "comfy_cast_weights") or \
15131516
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
1514-
unet_state_dict[k] = v
1517+
output_state_dict[k] = v
15151518
continue
1516-
key = "diffusion_model." + k
1519+
key = prefix + k
15171520
weight = comfy.utils.get_attr(self.model, key)
15181521
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
15191522
qt_state_dict = weight.state_dict(k)
15201523
caster = LazyCastingQuantizedParam(self, key)
15211524
for group_key in (x for x in qt_state_dict if x in original_state_dict):
15221525
if group_key in keys:
15231526
keys.remove(group_key)
1524-
unet_state_dict.pop(group_key, "")
1525-
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
1527+
output_state_dict.pop(group_key, "")
1528+
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
15261529
continue
1527-
unet_state_dict[k] = LazyCastingParam(self, key, weight)
1530+
output_state_dict[k] = LazyCastingParam(self, key, weight)
1531+
return output_state_dict
1532+
1533+
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
1534+
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
15281535
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
15291536

15301537
def __del__(self):

comfy/sd.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,13 @@ def get_sd(self):
423423
sd_clip[k] = sd_tokenizer[k]
424424
return sd_clip
425425

426+
def state_dict_for_saving(self):
427+
sd_clip = self.patcher.model_state_dict_for_saving()
428+
sd_tokenizer = self.tokenizer.state_dict()
429+
for k in sd_tokenizer:
430+
sd_clip[k] = sd_tokenizer[k]
431+
return sd_clip
432+
426433
def load_model(self, tokens={}):
427434
memory_used = 0
428435
if hasattr(self.cond_stage_model, "memory_estimation_function"):
@@ -1908,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
19081915
load_models = [model]
19091916
if clip is not None:
19101917
load_models.append(clip.load_model())
1911-
clip_sd = clip.get_sd()
1918+
clip_sd = clip.state_dict_for_saving()
19121919
vae_sd = None
19131920
if vae is not None:
19141921
vae_sd = vae.get_sd()

comfy_extras/nodes_model_merging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,8 @@ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
276276
for x in extra_pnginfo:
277277
metadata[x] = json.dumps(extra_pnginfo[x])
278278

279-
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
280-
clip_sd = clip.get_sd()
279+
clip.load_model()
280+
clip_sd = clip.state_dict_for_saving()
281281

282282
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
283283
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))

0 commit comments

Comments
 (0)