@@ -30,8 +30,11 @@ class PatchEmbedding(nn.Module):
3030 """
3131
3232 def __init__ (
33- self , img_size : int = 224 , patch_size : int = 16 ,
34- in_channels : int = 3 , embed_dim : int = 768
33+ self ,
34+ img_size : int = 224 ,
35+ patch_size : int = 16 ,
36+ in_channels : int = 3 ,
37+ embed_dim : int = 768 ,
3538 ):
3639 super ().__init__ ()
3740 self .img_size = img_size
@@ -42,7 +45,7 @@ def __init__(
4245 in_channels = in_channels ,
4346 out_channels = embed_dim ,
4447 kernel_size = patch_size ,
45- stride = patch_size
48+ stride = patch_size ,
4649 )
4750
4851 def forward (self , x : Tensor ) -> Tensor :
@@ -107,7 +110,7 @@ def forward(self, x: Tensor) -> Tensor:
107110 q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # (B, num_heads, N, head_dim)
108111
109112 # Scaled dot-product attention
110- attn = (q @ k .transpose (- 2 , - 1 )) * (self .head_dim ** - 0.5 )
113+ attn = (q @ k .transpose (- 2 , - 1 )) * (self .head_dim ** - 0.5 )
111114 attn = functional .softmax (attn , dim = - 1 )
112115 attn = self .attn_dropout (attn )
113116
@@ -131,7 +134,9 @@ class MLPBlock(nn.Module):
131134 dropout (float): Dropout rate
132135 """
133136
134- def __init__ (self , embed_dim : int = 768 , mlp_ratio : float = 4.0 , dropout : float = 0.0 ):
137+ def __init__ (
138+ self , embed_dim : int = 768 , mlp_ratio : float = 4.0 , dropout : float = 0.0
139+ ):
135140 super ().__init__ ()
136141 hidden_dim = int (embed_dim * mlp_ratio )
137142
@@ -170,8 +175,11 @@ class TransformerEncoderBlock(nn.Module):
170175 """
171176
172177 def __init__ (
173- self , embed_dim : int = 768 , num_heads : int = 12 ,
174- mlp_ratio : float = 4.0 , dropout : float = 0.1
178+ self ,
179+ embed_dim : int = 768 ,
180+ num_heads : int = 12 ,
181+ mlp_ratio : float = 4.0 ,
182+ dropout : float = 0.1 ,
175183 ):
176184 super ().__init__ ()
177185
@@ -227,7 +235,7 @@ def __init__(
227235 num_heads : int = 12 ,
228236 mlp_ratio : float = 4.0 ,
229237 dropout : float = 0.1 ,
230- emb_dropout : float = 0.1
238+ emb_dropout : float = 0.1 ,
231239 ):
232240 super ().__init__ ()
233241
@@ -245,10 +253,12 @@ def __init__(
245253 self .pos_dropout = nn .Dropout (emb_dropout )
246254
247255 # Transformer encoder blocks
248- self .blocks = nn .ModuleList ([
249- TransformerEncoderBlock (embed_dim , num_heads , mlp_ratio , dropout )
250- for _ in range (depth )
251- ])
256+ self .blocks = nn .ModuleList (
257+ [
258+ TransformerEncoderBlock (embed_dim , num_heads , mlp_ratio , dropout )
259+ for _ in range (depth )
260+ ]
261+ )
252262
253263 # Layer normalization and classifier
254264 self .norm = nn .LayerNorm (embed_dim )
@@ -325,7 +335,7 @@ def create_vit_model(
325335 num_heads : int = 12 ,
326336 mlp_ratio : float = 4.0 ,
327337 dropout : float = 0.1 ,
328- emb_dropout : float = 0.1
338+ emb_dropout : float = 0.1 ,
329339) -> VisionTransformer :
330340 """
331341 Factory function to create a Vision Transformer model.
@@ -355,11 +365,13 @@ def create_vit_model(
355365 num_heads = num_heads ,
356366 mlp_ratio = mlp_ratio ,
357367 dropout = dropout ,
358- emb_dropout = emb_dropout
368+ emb_dropout = emb_dropout ,
359369 )
360370
361371
362- def get_pretrained_vit (model_name : str = "vit_base_patch16_224" , num_classes : int = 1000 ) -> nn .Module :
372+ def get_pretrained_vit (
373+ model_name : str = "vit_base_patch16_224" , num_classes : int = 1000
374+ ) -> nn .Module :
363375 """
364376 Load a pretrained ViT model from torchvision.
365377
@@ -377,9 +389,9 @@ def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: in
377389 model = getattr (models , model_name )(pretrained = True )
378390 if num_classes != 1000 :
379391 # Replace the head for fine-tuning
380- if hasattr (model , ' heads' ):
392+ if hasattr (model , " heads" ):
381393 model .heads = nn .Linear (model .heads .in_features , num_classes )
382- elif hasattr (model , ' head' ):
394+ elif hasattr (model , " head" ):
383395 model .head = nn .Linear (model .head .in_features , num_classes )
384396 return model
385397 else :
@@ -411,7 +423,7 @@ def count_parameters(model: nn.Module) -> int:
411423 num_classes = 3 , # pizza, steak, sushi
412424 embed_dim = 768 ,
413425 depth = 12 ,
414- num_heads = 12
426+ num_heads = 12 ,
415427 )
416428
417429 print (f"Model created with { count_parameters (model ):,} parameters" )
0 commit comments