Skip to content

Commit e84a200

Browse files
authored
ops: opt out of deferred weight init if subclassed (Comfy-Org#12967)
If a subclass BYO _load_from_state_dict and doesnt call the super() the needed default init of these weights is missed and can lead to problems for uninitialized weights.
1 parent 192cb8e commit e84a200

1 file changed

Lines changed: 14 additions & 4 deletions

File tree

comfy/ops.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,10 @@ def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
336336
class Linear(torch.nn.Linear, CastWeightBiasOp):
337337

338338
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
339-
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
339+
# don't trust subclasses that BYO state dict loader to call us.
340+
if (not comfy.model_management.WINDOWS
341+
or not comfy.memory_management.aimdo_enabled
342+
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
340343
super().__init__(in_features, out_features, bias, device, dtype)
341344
return
342345

@@ -357,7 +360,9 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None
357360
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
358361
strict, missing_keys, unexpected_keys, error_msgs):
359362

360-
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
363+
if (not comfy.model_management.WINDOWS
364+
or not comfy.memory_management.aimdo_enabled
365+
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
361366
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
362367
missing_keys, unexpected_keys, error_msgs)
363368
disable_weight_init._lazy_load_from_state_dict(
@@ -564,7 +569,10 @@ class Embedding(torch.nn.Embedding, CastWeightBiasOp):
564569
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
565570
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
566571
_freeze=False, device=None, dtype=None):
567-
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
572+
# don't trust subclasses that BYO state dict loader to call us.
573+
if (not comfy.model_management.WINDOWS
574+
or not comfy.memory_management.aimdo_enabled
575+
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
568576
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
569577
norm_type, scale_grad_by_freq, sparse, _weight,
570578
_freeze, device, dtype)
@@ -590,7 +598,9 @@ def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=Non
590598
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
591599
strict, missing_keys, unexpected_keys, error_msgs):
592600

593-
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
601+
if (not comfy.model_management.WINDOWS
602+
or not comfy.memory_management.aimdo_enabled
603+
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
594604
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
595605
missing_keys, unexpected_keys, error_msgs)
596606
disable_weight_init._lazy_load_from_state_dict(

0 commit comments

Comments
 (0)