@@ -336,7 +336,10 @@ def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
336336 class Linear (torch .nn .Linear , CastWeightBiasOp ):
337337
338338 def __init__ (self , in_features , out_features , bias = True , device = None , dtype = None ):
339- if not comfy .model_management .WINDOWS or not comfy .memory_management .aimdo_enabled :
339+ # don't trust subclasses that BYO state dict loader to call us.
340+ if (not comfy .model_management .WINDOWS
341+ or not comfy .memory_management .aimdo_enabled
342+ or type (self )._load_from_state_dict is not disable_weight_init .Linear ._load_from_state_dict ):
340343 super ().__init__ (in_features , out_features , bias , device , dtype )
341344 return
342345
@@ -357,7 +360,9 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None
357360 def _load_from_state_dict (self , state_dict , prefix , local_metadata ,
358361 strict , missing_keys , unexpected_keys , error_msgs ):
359362
360- if not comfy .model_management .WINDOWS or not comfy .memory_management .aimdo_enabled :
363+ if (not comfy .model_management .WINDOWS
364+ or not comfy .memory_management .aimdo_enabled
365+ or type (self )._load_from_state_dict is not disable_weight_init .Linear ._load_from_state_dict ):
361366 return super ()._load_from_state_dict (state_dict , prefix , local_metadata , strict ,
362367 missing_keys , unexpected_keys , error_msgs )
363368 disable_weight_init ._lazy_load_from_state_dict (
@@ -564,7 +569,10 @@ class Embedding(torch.nn.Embedding, CastWeightBiasOp):
564569 def __init__ (self , num_embeddings , embedding_dim , padding_idx = None , max_norm = None ,
565570 norm_type = 2.0 , scale_grad_by_freq = False , sparse = False , _weight = None ,
566571 _freeze = False , device = None , dtype = None ):
567- if not comfy .model_management .WINDOWS or not comfy .memory_management .aimdo_enabled :
572+ # don't trust subclasses that BYO state dict loader to call us.
573+ if (not comfy .model_management .WINDOWS
574+ or not comfy .memory_management .aimdo_enabled
575+ or type (self )._load_from_state_dict is not disable_weight_init .Embedding ._load_from_state_dict ):
568576 super ().__init__ (num_embeddings , embedding_dim , padding_idx , max_norm ,
569577 norm_type , scale_grad_by_freq , sparse , _weight ,
570578 _freeze , device , dtype )
@@ -590,7 +598,9 @@ def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=Non
590598 def _load_from_state_dict (self , state_dict , prefix , local_metadata ,
591599 strict , missing_keys , unexpected_keys , error_msgs ):
592600
593- if not comfy .model_management .WINDOWS or not comfy .memory_management .aimdo_enabled :
601+ if (not comfy .model_management .WINDOWS
602+ or not comfy .memory_management .aimdo_enabled
603+ or type (self )._load_from_state_dict is not disable_weight_init .Embedding ._load_from_state_dict ):
594604 return super ()._load_from_state_dict (state_dict , prefix , local_metadata , strict ,
595605 missing_keys , unexpected_keys , error_msgs )
596606 disable_weight_init ._lazy_load_from_state_dict (
0 commit comments