@@ -17,10 +17,227 @@ def __getitem__(self, key):
1717 def __setitem__ (self , key , item ):
1818 setattr (self , key , item )
1919
20- def clip_preprocess (image , size = 224 , mean = [0.48145466 , 0.4578275 , 0.40821073 ], std = [0.26862954 , 0.26130258 , 0.27577711 ], crop = True ):
20+
def cubic_kernel(x, a: float = -0.75):
    """Evaluate the piecewise cubic interpolation kernel element-wise.

    Args:
        x: Tensor of signed distances between a sample position and a tap.
        a: Kernel sharpness coefficient (default ``-0.75``).

    Returns:
        Tensor of weights with the same shape as ``x``:
        ``(a + 2)|x|^3 - (a + 3)|x|^2 + 1`` where ``|x| <= 1``,
        ``a|x|^3 - 5a|x|^2 + 8a|x| - 4a`` where ``1 < |x| < 2``,
        and ``0`` everywhere else.
    """
    dist = torch.abs(x)
    dist_sq = dist ** 2
    dist_cu = dist ** 3

    # Central lobe, valid for |x| <= 1.
    near = (a + 2) * dist_cu - (a + 3) * dist_sq + 1
    # Side lobes, valid for 1 < |x| < 2.
    far = a * dist_cu - 5 * a * dist_sq + 8 * a * dist - 4 * a

    # The kernel has finite support: everything at distance >= 2 is zero.
    outside_center = torch.where(dist < 2, far, torch.zeros_like(x))
    return torch.where(dist <= 1, near, outside_center)
30+
def get_indices_weights(in_size, out_size, scale, device=None):
    """Build the 4-tap source indices and cubic weights for a 1-D resize.

    Each output position is mapped back onto the source axis with the
    half-pixel convention ``(i + 0.5) / scale - 0.5`` and is interpolated
    from the four source samples surrounding that position.

    Args:
        in_size: Length of the source axis; used to clamp the tap indices.
        out_size: Length of the destination axis.
        scale: Resize factor; the caller in this file passes
            ``out_size / in_size``.
        device: Optional device on which to create the returned tensors.
            ``None`` (the default) keeps the previous behaviour of using
            PyTorch's default device.

    Returns:
        Tuple ``(indices, weights)``, both of shape ``(out_size, 4)``:
        ``indices`` is a long tensor clamped to ``[0, in_size - 1]`` (taps
        falling outside the axis reuse the border sample) and ``weights``
        is a float tensor whose rows are normalised to sum to 1.
    """
    # OpenCV-style half-pixel mapping of output centres onto the source axis.
    centers = torch.arange(out_size, dtype=torch.float32, device=device)
    centers = (centers + 0.5) / scale - 0.5

    # The four taps sit at offsets -1, 0, +1, +2 from floor(center); build
    # the tap grid once and reuse it for both the distances and the indices.
    base = centers.floor().long()
    taps = base.unsqueeze(1) + torch.arange(-1, 3, device=device)

    # Weights are evaluated on the *unclamped* tap positions so that border
    # replication does not distort the kernel.
    weights = cubic_kernel(centers.unsqueeze(1) - taps)
    weights = weights / weights.sum(dim=1, keepdim=True)

    indices = taps.clamp(0, in_size - 1)

    return indices, weights
46+
def resize_cubic_1d(x, out_size, dim):
    """Resize one spatial axis of a 4-D tensor with 4-tap cubic interpolation.

    Args:
        x: Tensor that unpacks as ``(b, c, h, w)``.
        out_size: Target length of the axis being resized.
        dim: ``2`` resizes the height axis; any other value resizes the
            width axis (the caller in this file passes ``3``).

    Returns:
        Tensor of shape ``(b, c, out_size, w)`` when ``dim == 2``, otherwise
        ``(b, c, h, out_size)``.
    """
    b, c, h, w = x.shape
    resize_height = dim == 2
    in_size = h if resize_height else w
    scale = out_size / in_size

    indices, weights = get_indices_weights(in_size, out_size, scale)
    # The helper builds its tensors on the default device; move them next to
    # ``x`` so the multiply below does not fail with a device mismatch when
    # ``x`` lives on an accelerator.
    indices = indices.to(x.device)
    weights = weights.to(x.device)

    # Flatten to (rows, in_size) with the axis being resized last.
    if resize_height:
        rows = x.permute(0, 1, 3, 2).reshape(-1, h)
    else:
        rows = x.reshape(-1, w)

    # (rows, out_size, 4): the four source samples feeding each output sample.
    gathered = rows[:, indices]
    out = (gathered * weights.unsqueeze(0)).sum(dim=2)

    if resize_height:
        # Undo the earlier transpose so height is back on axis 2.
        return out.reshape(b, c, w, out_size).permute(0, 1, 3, 2)
    return out.reshape(b, c, h, out_size)
69+
def resize_cubic(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """
    Bicubic resize implemented in pure PyTorch, intended to match OpenCV's
    INTER_CUBIC interpolation.

    The input is treated as channels-last: a 3-D tensor gets a leading batch
    axis, and the last axis is then moved to position 1 before resizing.
    The interpolation is separable: the height axis is resized first, then
    the width axis.

    Args:
        img: Image tensor, 3-D or 4-D with channels last.
        size: ``(out_h, out_w)`` target size.

    Returns:
        Channels-first tensor of shape ``(batch, channels, out_h, out_w)``.
        The layout is NOT converted back to channels-last.
    """
    batched = img.unsqueeze(0) if img.ndim == 3 else img

    # channels-last -> channels-first, as expected by resize_cubic_1d
    channels_first = batched.permute(0, 3, 1, 2)

    target_h, target_w = size
    height_done = resize_cubic_1d(channels_first, target_h, dim=2)
    return resize_cubic_1d(height_done, target_w, dim=3)
85+
def resize_area(img: torch.Tensor, size: tuple) -> torch.Tensor:
    """Area-weighted resize in pure PyTorch (intended to mirror OpenCV's INTER_AREA).

    Every output pixel is the average of the input pixels covered by its
    footprint, each weighted by the fraction of the input pixel that lies
    inside the footprint.

    Layout handling:
        * 4-D input is used as ``(B, C, H, W)`` and a 4-D tensor is returned.
        * 3-D input whose first dimension is ``<= 4`` is treated as
          ``(C, H, W)``; the result is ``(C, out_h, out_w)``.
        * Any other 3-D input is treated as ``(H, W, C)``; the result is
          ``(out_h, out_w, C)``. NOTE: this is a heuristic, so an
          ``(H, W, C)`` image with ``H <= 4`` is read as channels-first.

    Args:
        img: Image tensor with 3 or 4 dimensions.
        size: ``(out_h, out_w)`` target size; both must be positive.

    Returns:
        Resized tensor in the same layout as the input. The accumulator is
        float32.

    Raises:
        ValueError: If ``img`` is not 3-D/4-D or a target size is not positive.
    """
    is_hwc = False
    added_batch = False

    if img.ndim == 3:
        # Remember that the batch axis is ours so that it is stripped again on
        # return; a genuine 4-D input keeps its batch axis even when B == 1.
        added_batch = True
        if img.shape[0] <= 4:
            img = img.unsqueeze(0)
        else:
            is_hwc = True
            img = img.permute(2, 0, 1).unsqueeze(0)
    elif img.ndim != 4:
        raise ValueError("Expected image with 3 or 4 dims.")

    B, C, H, W = img.shape
    out_h, out_w = size
    if out_h <= 0 or out_w <= 0:
        raise ValueError("Target size must be positive in both dimensions.")

    scale_y = H / out_h
    scale_x = W / out_w

    device = img.device

    # Footprint of every output pixel in input coordinates. The end is capped
    # at the image edge so float rounding cannot push it outside the image.
    y_start = torch.arange(out_h, device=device).float() * scale_y
    y_end = (y_start + scale_y).clamp(max=H)
    x_start = torch.arange(out_w, device=device).float() * scale_x
    x_end = (x_start + scale_x).clamp(max=W)

    # First contributing input index and the widest footprint per axis.
    y_first = torch.floor(y_start).long()
    x_first = torch.floor(x_start).long()
    max_kernel_h = int(torch.max(torch.ceil(y_end).long() - y_first).item())
    max_kernel_w = int(torch.max(torch.ceil(x_end).long() - x_first).item())

    def axis_terms(first, start, end, offset, limit, unsqueeze_dim):
        # Overlap between the footprint [start, end) and the input pixel
        # [idx, idx + 1). The overlap is computed from the UNCLAMPED index so
        # that offsets past a pixel's own footprint get zero weight; clamping
        # is applied only to keep the gather in bounds.
        idx = (first + offset).unsqueeze(unsqueeze_dim)
        lo = idx.float()
        overlap = (
            torch.min(end.unsqueeze(unsqueeze_dim), lo + 1.0)
            - torch.max(start.unsqueeze(unsqueeze_dim), lo)
        ).clamp(min=0)
        return idx.clamp(0, limit - 1), overlap

    # The column terms do not depend on the row offset: compute them once.
    x_terms = [axis_terms(x_first, x_start, x_end, dx, W, 0) for dx in range(max_kernel_w)]

    output = torch.zeros((B, C, out_h, out_w), dtype=torch.float32, device=device)
    area = torch.zeros((out_h, out_w), dtype=torch.float32, device=device)

    for dy in range(max_kernel_h):
        y_gather, y_weight = axis_terms(y_first, y_start, y_end, dy, H, 1)
        for x_gather, x_weight in x_terms:
            # (out_h, 1) * (1, out_w) -> (out_h, out_w)
            weight = y_weight * x_weight

            # Broadcast advanced indexing yields (B, C, out_h, out_w).
            pixels = img[:, :, y_gather, x_gather]

            output += pixels * weight
            area += weight

    # Normalize by the total covered area of each output pixel.
    output /= area

    if is_hwc:
        return output[0].permute(1, 2, 0)
    if added_batch:
        return output[0]
    return output
166+
def recenter(image, border_ratio: float = 0.2):
    """Crop an image to its non-zero alpha region, centre it on a square
    canvas and composite it over a white background.

    Args:
        image: ``(H, W, C)`` tensor. When the last dimension is 4, channel 3
            is used as the mask; otherwise a constant 255 channel is appended
            and used as the mask. Assumes values are on a 0-255 scale
            (the visible caller passes uint8) - TODO confirm for other callers.
        border_ratio: Fraction of the canvas side left as border; the content
            is scaled so that its longer side is
            ``int(max(H, W) * (1 - border_ratio))``.

    Returns:
        uint8 tensor of shape ``(S, S, 3)`` with ``S = max(H, W)``.
        NOTE(review): the canvas is allocated on the default device, not on
        ``image.device`` - confirm this is intended for accelerator inputs.

    Raises:
        ValueError: If the mask has no non-zero pixel, or the non-zero
            region has zero extent along either axis.
    """
    if image.shape[-1] == 4:
        mask = image[..., 3]
    else:
        mask = torch.ones_like(image[..., 0:1]) * 255
        image = torch.concatenate([image, mask], axis=-1)
        mask = mask[..., 0]

    H, W, C = image.shape

    size = max(H, W)
    result = torch.zeros((size, size, C), dtype=torch.uint8)

    # as_tuple to match numpy behaviour: first tensor holds the row indices
    # (called "x" here), the second the column indices ("y").
    x_coords, y_coords = torch.nonzero(mask, as_tuple=True)

    # min()/max() raise RuntimeError on an empty tensor, so a fully
    # transparent input has to be rejected before reducing.
    if x_coords.numel() == 0:
        raise ValueError('input image is empty')

    y_min, y_max = y_coords.min(), y_coords.max()
    x_min, x_max = x_coords.min(), x_coords.max()

    # Extents use an exclusive end (max - min), matching the slice below.
    h = x_max - x_min
    w = y_max - y_min

    if h == 0 or w == 0:
        raise ValueError('input image is empty')

    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)

    h2 = int(h * scale)
    w2 = int(w * scale)

    # Top-left corner that centres the scaled content on the square canvas.
    x2_min = (size - h2) // 2
    x2_max = x2_min + h2

    y2_min = (size - w2) // 2
    y2_max = y2_min + w2

    # note: opencv takes columns first (opposite to pytorch and numpy that take the row first)
    result[x2_min:x2_max, y2_min:y2_max] = resize_area(image[x_min:x_max, y_min:y_max], (h2, w2))

    bg = torch.ones((result.shape[0], result.shape[1], 3), dtype=torch.uint8) * 255

    # Alpha-composite the colour channels over the white background.
    alpha = result[..., 3:].to(torch.float32) / 255
    result = result[..., :3] * alpha + bg * (1 - alpha)

    return result.clip(0, 255).to(torch.uint8)
218+
219+ def clip_preprocess (image , size = 224 , mean = [0.48145466 , 0.4578275 , 0.40821073 ], std = [0.26862954 , 0.26130258 , 0.27577711 ],
220+ crop = True , value_range = (- 1 , 1 ), border_ratio : float = None , recenter_size : int = 512 ):
221+
222+ if border_ratio is not None :
223+
224+ image = (image * 255 ).clamp (0 , 255 ).to (torch .uint8 )
225+ image = [recenter (img , border_ratio = border_ratio ) for img in image ]
226+
227+ image = torch .stack (image , dim = 0 )
228+ image = resize_cubic (image , size = (recenter_size , recenter_size ))
229+
230+ image = image / 255 * 2 - 1
231+ low , high = value_range
232+
233+ image = (image - low ) / (high - low )
234+ image = image .permute (0 , 2 , 3 , 1 )
235+
21236 image = image [:, :, :, :3 ] if image .shape [3 ] > 3 else image
237+
22238 mean = torch .tensor (mean , device = image .device , dtype = image .dtype )
23239 std = torch .tensor (std , device = image .device , dtype = image .dtype )
240+
24241 image = image .movedim (- 1 , 1 )
25242 if not (image .shape [2 ] == size and image .shape [3 ] == size ):
26243 if crop :
@@ -29,7 +246,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
29246 else :
30247 scale_size = (size , size )
31248
32- image = torch .nn .functional .interpolate (image , size = scale_size , mode = "bicubic" , antialias = True )
249+ image = torch .nn .functional .interpolate (image , size = scale_size , mode = "bilinear" if border_ratio is not None else " bicubic" , antialias = True )
33250 h = (image .shape [2 ] - size )// 2
34251 w = (image .shape [3 ] - size )// 2
35252 image = image [:,:,h :h + size ,w :w + size ]
@@ -71,9 +288,9 @@ def load_sd(self, sd):
71288 def get_sd (self ):
72289 return self .model .state_dict ()
73290
74- def encode_image (self , image , crop = True ):
291+ def encode_image (self , image , crop = True , border_ratio : float = None ):
75292 comfy .model_management .load_model_gpu (self .patcher )
76- pixel_values = clip_preprocess (image .to (self .load_device ), size = self .image_size , mean = self .image_mean , std = self .image_std , crop = crop ).float ()
293+ pixel_values = clip_preprocess (image .to (self .load_device ), size = self .image_size , mean = self .image_mean , std = self .image_std , crop = crop , border_ratio = border_ratio ).float ()
77294 out = self .model (pixel_values = pixel_values , intermediate_output = 'all' if self .return_all_hidden_states else - 2 )
78295
79296 outputs = Output ()
@@ -136,8 +353,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
136353 json_config = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "clip_vision_config_vitl_336.json" )
137354 else :
138355 json_config = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "clip_vision_config_vitl.json" )
139- elif "embeddings.patch_embeddings.projection.weight" in sd :
356+
357+ # Dinov2
358+ elif 'encoder.layer.39.layer_scale2.lambda1' in sd :
140359 json_config = os .path .join (os .path .join (os .path .dirname (os .path .realpath (__file__ )), "image_encoders" ), "dino2_giant.json" )
360+ elif 'encoder.layer.23.layer_scale2.lambda1' in sd :
361+ json_config = os .path .join (os .path .join (os .path .dirname (os .path .realpath (__file__ )), "image_encoders" ), "dino2_large.json" )
141362 else :
142363 return None
143364
0 commit comments