33
44import torch
55import comfy .model_management
6- from .base import WeightAdapterBase , weight_decompose
6+ from .base import WeightAdapterBase , WeightAdapterTrainBase , weight_decompose
7+
8+
class HadaWeight(torch.autograd.Function):
    """Autograd function producing ``((w1u @ w1d) * (w2u @ w2d)) * scale``.

    The backward pass is written by hand: it recomputes each matrix product
    from the saved factors instead of relying on tensors recorded by autograd.
    """

    @staticmethod
    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
        """Return the scaled element-wise product of the two factor products."""
        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
        first = w1u @ w1d
        second = w2u @ w2d
        return (first * second) * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``w1u, w1d, w2u, w2d``; ``scale`` gets ``None``."""
        w1d, w1u, w2d, w2u, scale = ctx.saved_tensors
        scaled_grad = grad_out * scale

        # Gradient reaching the first product is the upstream gradient
        # weighted element-wise by the second product.
        grad_first = scaled_grad * (w2u @ w2d)
        grad_w1u = grad_first @ w1d.T
        grad_w1d = w1u.T @ grad_first

        # Symmetric case for the second product.
        grad_second = scaled_grad * (w1u @ w1d)
        del grad_first
        grad_w2u = grad_second @ w2d.T
        grad_w2d = w2u.T @ grad_second
        del grad_second

        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
30+
31+
class HadaWeightTucker(torch.autograd.Function):
    """Autograd function for the Tucker variant of the Hadamard weight delta.

    For each branch ``k`` in ``{1, 2}`` the forward pass builds

        rebuild_k[p, r, ...] = sum_{i, j} t_k[i, j, ...] * wkd[j, r] * wku[i, p]

    and returns ``rebuild_1 * rebuild_2 * scale``. The backward pass is written
    by hand and recomputes the intermediate contractions from the saved inputs.
    """

    @staticmethod
    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
        """Return ``rebuild_1 * rebuild_2 * scale`` (see class docstring)."""
        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``t1, w1u, w1d, t2, w2u, w2d``; ``scale`` gets ``None``."""
        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Core tensors contracted with their "down" matrices: temp_k[i, r, ...].
        # Each branch needs its OWN temp_k for the w_ku gradient and the OTHER
        # branch's rebuild as the element-wise weight on grad_out.
        temp1 = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
        temp2 = torch.einsum("i j ..., j r -> i r ...", t2, w2d)

        # ---- branch 1: gradient reaching rebuild_1 is grad_out * rebuild_2 ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp2, w2u)
        grad_w = rebuild * grad_out
        del rebuild

        # d rebuild_1 / d w1u contracts with temp1 (t1 . w1d), not temp2.
        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp1, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
        del grad_w

        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
        del grad_temp

        # ---- branch 2: gradient reaching rebuild_2 is grad_out * rebuild_1 ----
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp1, w1u)
        del temp1
        grad_w = rebuild * grad_out
        del rebuild

        # d rebuild_2 / d w2u contracts with temp2 (t2 . w2d), not temp1.
        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp2, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
        del grad_w, temp2

        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
        del grad_temp
        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
75+
76+
class LohaDiff(WeightAdapterTrainBase):
    """Trainable LoHa weight delta built from a LoHa weights tuple.

    Holds the four factor matrices (plus two optional Tucker cores) as
    parameters and, when called with a base weight, returns that weight with
    the scaled LoHa delta added.
    """

    def __init__(self, weights):
        """Wrap the entries of ``weights`` as parameters.

        ``weights`` is unpacked as
        ``(w1a, w1b, alpha, w2a, w2b, t1, t2, <ignored>)``.
        """
        super().__init__()
        first_up, first_down, alpha, second_up, second_down, core1, core2, _ = weights

        # Factor matrices are always trainable.
        self.hada_w1_a = torch.nn.Parameter(first_up)
        self.hada_w1_b = torch.nn.Parameter(first_down)
        self.hada_w2_a = torch.nn.Parameter(second_up)
        self.hada_w2_b = torch.nn.Parameter(second_down)

        # The Tucker path is used only when both cores are supplied; the
        # attributes exist either way so callers can read them unconditionally.
        self.use_tucker = core1 is not None and core2 is not None
        if self.use_tucker:
            self.hada_t1 = torch.nn.Parameter(core1)
            self.hada_t2 = torch.nn.Parameter(core2)
        else:
            self.hada_t1 = None
            self.hada_t2 = None

        # Rank comes from the first dimension of w1b; alpha is kept as a
        # parameter but excluded from gradient updates.
        self.rank = first_down.shape[0]
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        """Return ``w`` plus the LoHa delta, cast back to ``w``'s dtype."""
        scale = self.alpha / self.rank

        if self.use_tucker:
            delta = HadaWeightTucker.apply(
                self.hada_t1,
                self.hada_w1_a,
                self.hada_w1_b,
                self.hada_t2,
                self.hada_w2_a,
                self.hada_w2_b,
                scale,
            )
        else:
            delta = HadaWeight.apply(
                self.hada_w1_a,
                self.hada_w1_b,
                self.hada_w2_a,
                self.hada_w2_b,
                scale,
            )

        # Do the addition in the delta's dtype/device, then restore w's dtype.
        patched = w.to(delta) + delta.reshape(w.shape)
        return patched.to(w.dtype)

    def passive_memory_usage(self):
        """Return the total size in bytes of everything in ``self.parameters()``."""
        total_bytes = 0
        for param in self.parameters():
            total_bytes += param.numel() * param.element_size()
        return total_bytes
7120
8121
9122class LoHaAdapter (WeightAdapterBase ):
@@ -13,6 +126,25 @@ def __init__(self, loaded_keys, weights):
13126 self .loaded_keys = loaded_keys
14127 self .weights = weights
15128
129+ @classmethod
130+ def create_train (cls , weight , rank = 1 , alpha = 1.0 ):
131+ out_dim = weight .shape [0 ]
132+ in_dim = weight .shape [1 :].numel ()
133+ mat1 = torch .empty (out_dim , rank , device = weight .device , dtype = weight .dtype )
134+ mat2 = torch .empty (rank , in_dim , device = weight .device , dtype = weight .dtype )
135+ torch .nn .init .normal_ (mat1 , 0.1 )
136+ torch .nn .init .constant_ (mat2 , 0.0 )
137+ mat3 = torch .empty (out_dim , rank , device = weight .device , dtype = weight .dtype )
138+ mat4 = torch .empty (rank , in_dim , device = weight .device , dtype = weight .dtype )
139+ torch .nn .init .normal_ (mat3 , 0.1 )
140+ torch .nn .init .normal_ (mat4 , 0.01 )
141+ return LohaDiff (
142+ (mat1 , mat2 , alpha , mat3 , mat4 , None , None , None )
143+ )
144+
145+ def to_train (self ):
146+ return LohaDiff (self .weights )
147+
16148 @classmethod
17149 def load (
18150 cls ,
0 commit comments