@@ -29,8 +29,11 @@ class PatchEmbedding(nn.Module):
2929 """
3030
3131 def __init__ (
32- self , img_size : int = 224 , patch_size : int = 16 ,
33- in_channels : int = 3 , embed_dim : int = 768
32+ self ,
33+ img_size : int = 224 ,
34+ patch_size : int = 16 ,
35+ in_channels : int = 3 ,
36+ embed_dim : int = 768 ,
3437 ):
3538 super ().__init__ ()
3639 self .img_size = img_size
@@ -41,7 +44,7 @@ def __init__(
4144 in_channels = in_channels ,
4245 out_channels = embed_dim ,
4346 kernel_size = patch_size ,
44- stride = patch_size
47+ stride = patch_size ,
4548 )
4649
4750 def forward (self , x : Tensor ) -> Tensor :
@@ -106,7 +109,7 @@ def forward(self, x: Tensor) -> Tensor:
106109 q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # (B, num_heads, N, head_dim)
107110
108111 # Scaled dot-product attention
109- attn = (q @ k .transpose (- 2 , - 1 )) * (self .head_dim ** - 0.5 )
112+ attn = (q @ k .transpose (- 2 , - 1 )) * (self .head_dim ** - 0.5 )
110113 attn = functional .softmax (attn , dim = - 1 )
111114 attn = self .attn_dropout (attn )
112115
@@ -130,7 +133,9 @@ class MLPBlock(nn.Module):
130133 dropout (float): Dropout rate
131134 """
132135
133- def __init__ (self , embed_dim : int = 768 , mlp_ratio : float = 4.0 , dropout : float = 0.0 ):
136+ def __init__ (
137+ self , embed_dim : int = 768 , mlp_ratio : float = 4.0 , dropout : float = 0.0
138+ ):
134139 super ().__init__ ()
135140 hidden_dim = int (embed_dim * mlp_ratio )
136141
@@ -169,8 +174,11 @@ class TransformerEncoderBlock(nn.Module):
169174 """
170175
171176 def __init__ (
172- self , embed_dim : int = 768 , num_heads : int = 12 ,
173- mlp_ratio : float = 4.0 , dropout : float = 0.1
177+ self ,
178+ embed_dim : int = 768 ,
179+ num_heads : int = 12 ,
180+ mlp_ratio : float = 4.0 ,
181+ dropout : float = 0.1 ,
174182 ):
175183 super ().__init__ ()
176184
@@ -226,7 +234,7 @@ def __init__(
226234 num_heads : int = 12 ,
227235 mlp_ratio : float = 4.0 ,
228236 dropout : float = 0.1 ,
229- emb_dropout : float = 0.1
237+ emb_dropout : float = 0.1 ,
230238 ):
231239 super ().__init__ ()
232240
@@ -244,10 +252,12 @@ def __init__(
244252 self .pos_dropout = nn .Dropout (emb_dropout )
245253
246254 # Transformer encoder blocks
247- self .blocks = nn .ModuleList ([
248- TransformerEncoderBlock (embed_dim , num_heads , mlp_ratio , dropout )
249- for _ in range (depth )
250- ])
255+ self .blocks = nn .ModuleList (
256+ [
257+ TransformerEncoderBlock (embed_dim , num_heads , mlp_ratio , dropout )
258+ for _ in range (depth )
259+ ]
260+ )
251261
252262 # Layer normalization and classifier
253263 self .norm = nn .LayerNorm (embed_dim )
@@ -324,7 +334,7 @@ def create_vit_model(
324334 num_heads : int = 12 ,
325335 mlp_ratio : float = 4.0 ,
326336 dropout : float = 0.1 ,
327- emb_dropout : float = 0.1
337+ emb_dropout : float = 0.1 ,
328338) -> VisionTransformer :
329339 """
330340 Factory function to create a Vision Transformer model.
@@ -354,11 +364,13 @@ def create_vit_model(
354364 num_heads = num_heads ,
355365 mlp_ratio = mlp_ratio ,
356366 dropout = dropout ,
357- emb_dropout = emb_dropout
367+ emb_dropout = emb_dropout ,
358368 )
359369
360370
361- def get_pretrained_vit (model_name : str = "vit_base_patch16_224" , num_classes : int = 1000 ) -> nn .Module :
371+ def get_pretrained_vit (
372+ model_name : str = "vit_base_patch16_224" , num_classes : int = 1000
373+ ) -> nn .Module :
362374 """
363375 Load a pretrained ViT model from torchvision.
364376
@@ -376,9 +388,9 @@ def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: in
376388 model = getattr (models , model_name )(pretrained = True )
377389 if num_classes != 1000 :
378390 # Replace the head for fine-tuning
379- if hasattr (model , ' heads' ):
391+ if hasattr (model , " heads" ):
380392 model .heads = nn .Linear (model .heads .in_features , num_classes )
381- elif hasattr (model , ' head' ):
393+ elif hasattr (model , " head" ):
382394 model .head = nn .Linear (model .head .in_features , num_classes )
383395 return model
384396 else :
@@ -410,7 +422,7 @@ def count_parameters(model: nn.Module) -> int:
410422 num_classes = 3 , # pizza, steak, sushi
411423 embed_dim = 768 ,
412424 depth = 12 ,
413- num_heads = 12
425+ num_heads = 12 ,
414426 )
415427
416428 print (f"Model created with { count_parameters (model ):,} parameters" )
0 commit comments