@@ -1493,38 +1493,45 @@ def clean_hooks(self):
14931493 self .unpatch_hooks ()
14941494 self .clear_cached_hook_weights ()
14951495
1496- def state_dict_for_saving (self , clip_state_dict = None , vae_state_dict = None , clip_vision_state_dict = None ):
1497- original_state_dict = self .model .diffusion_model .state_dict ()
1498- unet_state_dict = {}
1496+ def model_state_dict_for_saving (self , model = None , prefix = "" ):
1497+ if model is None :
1498+ model = self .model
1499+
1500+ original_state_dict = model .state_dict ()
1501+ output_state_dict = {}
14991502 keys = list (original_state_dict )
15001503 while len (keys ) > 0 :
15011504 k = keys .pop (0 )
15021505 v = original_state_dict [k ]
15031506 op_keys = k .rsplit ('.' , 1 )
15041507 if (len (op_keys ) < 2 ) or op_keys [1 ] not in ["weight" , "bias" ]:
1505- unet_state_dict [k ] = v
1508+ output_state_dict [k ] = v
15061509 continue
15071510 try :
1508- op = comfy .utils .get_attr (self . model . diffusion_model , op_keys [0 ])
1511+ op = comfy .utils .get_attr (model , op_keys [0 ])
15091512 except :
1510- unet_state_dict [k ] = v
1513+ output_state_dict [k ] = v
15111514 continue
15121515 if not op or not hasattr (op , "comfy_cast_weights" ) or \
15131516 (hasattr (op , "comfy_patched_weights" ) and op .comfy_patched_weights == True ):
1514- unet_state_dict [k ] = v
1517+ output_state_dict [k ] = v
15151518 continue
1516- key = "diffusion_model." + k
1519+ key = prefix + k
15171520 weight = comfy .utils .get_attr (self .model , key )
15181521 if isinstance (weight , QuantizedTensor ) and k in original_state_dict :
15191522 qt_state_dict = weight .state_dict (k )
15201523 caster = LazyCastingQuantizedParam (self , key )
15211524 for group_key in (x for x in qt_state_dict if x in original_state_dict ):
15221525 if group_key in keys :
15231526 keys .remove (group_key )
1524- unet_state_dict .pop (group_key , "" )
1525- unet_state_dict [group_key ] = LazyCastingParamPiece (caster , "diffusion_model." + group_key , original_state_dict [group_key ])
1527+ output_state_dict .pop (group_key , "" )
1528+ output_state_dict [group_key ] = LazyCastingParamPiece (caster , prefix + group_key , original_state_dict [group_key ])
15261529 continue
1527- unet_state_dict [k ] = LazyCastingParam (self , key , weight )
1530+ output_state_dict [k ] = LazyCastingParam (self , key , weight )
1531+ return output_state_dict
1532+
1533+ def state_dict_for_saving (self , clip_state_dict = None , vae_state_dict = None , clip_vision_state_dict = None ):
1534+ unet_state_dict = self .model_state_dict_for_saving (self .model .diffusion_model , "diffusion_model." )
15281535 return self .model .state_dict_for_saving (unet_state_dict , clip_state_dict = clip_state_dict , vae_state_dict = vae_state_dict , clip_vision_state_dict = clip_vision_state_dict )
15291536
15301537 def __del__ (self ):
0 commit comments