11"""
2- Vision Transformer (ViT) Implementation
2+ Vision Transformer (ViT) Implementation.
33
44This module contains a PyTorch implementation of the Vision Transformer (ViT)
5- architecture based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
5+ architecture based on the paper "An Image is Worth 16x16 Words:
6+ Transformers for Image Recognition at Scale".
67
78Key Components:
89- Patch Embedding
1213- Vision Transformer Model
1314"""
1415
15- import torch
16- import torch .nn as nn
17- import torch .nn .functional as F
18- from torch import Tensor
19- from typing import Optional , Tuple
20- import math
16+ from torch import Tensor , nn
17+ import torch .nn .functional as functional
2118
2219
2320class PatchEmbedding (nn .Module ):
2421 """
25- Creates patch embeddings from input images as described in Equation 1 of ViT paper .
22+ Creates patch embeddings from input images as described in Equation 1.
2623
2724 Args:
2825 img_size (int): Size of input image (assumed square)
@@ -32,11 +29,8 @@ class PatchEmbedding(nn.Module):
3229 """
3330
3431 def __init__ (
35- self ,
36- img_size : int = 224 ,
37- patch_size : int = 16 ,
38- in_channels : int = 3 ,
39- embed_dim : int = 768 ,
32+ self , img_size : int = 224 , patch_size : int = 16 ,
33+ in_channels : int = 3 , embed_dim : int = 768
4034 ):
4135 super ().__init__ ()
4236 self .img_size = img_size
@@ -47,7 +41,7 @@ def __init__(
4741 in_channels = in_channels ,
4842 out_channels = embed_dim ,
4943 kernel_size = patch_size ,
50- stride = patch_size ,
44+ stride = patch_size
5145 )
5246
5347 def forward (self , x : Tensor ) -> Tensor :
@@ -68,7 +62,7 @@ def forward(self, x: Tensor) -> Tensor:
6862
6963class MultiHeadSelfAttention (nn .Module ):
7064 """
71- Multi-Head Self Attention (MSA) block as described in Equation 2 of ViT paper .
65+ Multi-Head Self Attention (MSA) block as described in Equation 2.
7266
7367 Args:
7468 embed_dim (int): Dimension of embedding
@@ -101,23 +95,23 @@ def forward(self, x: Tensor) -> Tensor:
10195 Returns:
10296 Tensor: Output tensor of same shape as input
10397 """
104- B , N , C = x .shape
98+ batch_size , num_patches , channels = x .shape
10599
106100 # Create Q, K, V
107101 qkv = (
108102 self .qkv (x )
109- .reshape (B , N , 3 , self .num_heads , self .head_dim )
103+ .reshape (batch_size , num_patches , 3 , self .num_heads , self .head_dim )
110104 .permute (2 , 0 , 3 , 1 , 4 )
111105 )
112106 q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # (B, num_heads, N, head_dim)
113107
114108 # Scaled dot-product attention
115- attn = (q @ k .transpose (- 2 , - 1 )) * (self .head_dim ** - 0.5 ) # (B, num_heads, N, N )
116- attn = F .softmax (attn , dim = - 1 )
109+ attn = (q @ k .transpose (- 2 , - 1 )) * (self .head_dim ** - 0.5 )
110+ attn = functional .softmax (attn , dim = - 1 )
117111 attn = self .attn_dropout (attn )
118112
119113 # Apply attention to values
120- x = (attn @ v ).transpose (1 , 2 ).reshape (B , N , C ) # (B, N, embed_dim )
114+ x = (attn @ v ).transpose (1 , 2 ).reshape (batch_size , num_patches , channels )
121115
122116 # Projection
123117 x = self .proj (x )
@@ -128,17 +122,15 @@ def forward(self, x: Tensor) -> Tensor:
128122
129123class MLPBlock (nn .Module ):
130124 """
131- Multilayer Perceptron (MLP) block as described in Equation 3 of ViT paper .
125+ Multilayer Perceptron (MLP) block as described in Equation 3.
132126
133127 Args:
134128 embed_dim (int): Dimension of embedding
135129 mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
136130 dropout (float): Dropout rate
137131 """
138132
139- def __init__ (
140- self , embed_dim : int = 768 , mlp_ratio : float = 4.0 , dropout : float = 0.0
141- ):
133+ def __init__ (self , embed_dim : int = 768 , mlp_ratio : float = 4.0 , dropout : float = 0.0 ):
142134 super ().__init__ ()
143135 hidden_dim = int (embed_dim * mlp_ratio )
144136
@@ -177,11 +169,8 @@ class TransformerEncoderBlock(nn.Module):
177169 """
178170
179171 def __init__ (
180- self ,
181- embed_dim : int = 768 ,
182- num_heads : int = 12 ,
183- mlp_ratio : float = 4.0 ,
184- dropout : float = 0.1 ,
172+ self , embed_dim : int = 768 , num_heads : int = 12 ,
173+ mlp_ratio : float = 4.0 , dropout : float = 0.1
185174 ):
186175 super ().__init__ ()
187176
@@ -237,7 +226,7 @@ def __init__(
237226 num_heads : int = 12 ,
238227 mlp_ratio : float = 4.0 ,
239228 dropout : float = 0.1 ,
240- emb_dropout : float = 0.1 ,
229+ emb_dropout : float = 0.1
241230 ):
242231 super ().__init__ ()
243232
@@ -255,12 +244,10 @@ def __init__(
255244 self .pos_dropout = nn .Dropout (emb_dropout )
256245
257246 # Transformer encoder blocks
258- self .blocks = nn .ModuleList (
259- [
260- TransformerEncoderBlock (embed_dim , num_heads , mlp_ratio , dropout )
261- for _ in range (depth )
262- ]
263- )
247+ self .blocks = nn .ModuleList ([
248+ TransformerEncoderBlock (embed_dim , num_heads , mlp_ratio , dropout )
249+ for _ in range (depth )
250+ ])
264251
265252 # Layer normalization and classifier
266253 self .norm = nn .LayerNorm (embed_dim )
@@ -300,14 +287,14 @@ def forward(self, x: Tensor) -> Tensor:
300287 Returns:
301288 Tensor: Output logits of shape (B, num_classes)
302289 """
303- B = x .shape [0 ]
290+ batch_size = x .shape [0 ]
304291
305292 # Create patch embeddings
306293 x = self .patch_embed (x ) # (B, n_patches, embed_dim)
307294
308295 # Add class token
309- cls_tokens = self .cls_token .expand (B , - 1 , - 1 ) # (B, 1, embed_dim )
310- x = torch .cat ((cls_tokens , x ), dim = 1 ) # (B, n_patches + 1, embed_dim)
296+ cls_tokens = self .cls_token .expand (batch_size , - 1 , - 1 )
297+ x = torch .cat ((cls_tokens , x ), dim = 1 )
311298
312299 # Add position embedding and apply dropout
313300 x = x + self .pos_embed
@@ -337,7 +324,7 @@ def create_vit_model(
337324 num_heads : int = 12 ,
338325 mlp_ratio : float = 4.0 ,
339326 dropout : float = 0.1 ,
340- emb_dropout : float = 0.1 ,
327+ emb_dropout : float = 0.1
341328) -> VisionTransformer :
342329 """
343330 Factory function to create a Vision Transformer model.
@@ -367,13 +354,11 @@ def create_vit_model(
367354 num_heads = num_heads ,
368355 mlp_ratio = mlp_ratio ,
369356 dropout = dropout ,
370- emb_dropout = emb_dropout ,
357+ emb_dropout = emb_dropout
371358 )
372359
373360
def _head_in_features(head: nn.Module) -> int:
    """Return the input width of a classifier head.

    The head may be a single ``nn.Linear`` or a container (for example an
    ``nn.Sequential``) holding one or more linear layers. Because the caller
    replaces the *whole* head, the relevant width is that of the first linear
    layer the input reaches.
    """
    for layer in head.modules():  # modules() yields ``head`` itself first
        if isinstance(layer, nn.Linear):
            return layer.in_features
    # No linear layer found: fall back to the plain attribute access so an
    # unsupported head type still fails with AttributeError.
    return head.in_features


def get_pretrained_vit(
    model_name: str = "vit_base_patch16_224", num_classes: int = 1000
) -> nn.Module:
    """
    Load a pretrained ViT model from torchvision.

    Args:
        model_name (str): Name of a model builder exposed by
            ``torchvision.models``.
            NOTE(review): the default looks like a timm-style name;
            torchvision's ViT builders are named like ``vit_b_16``, so the
            default probably raises ``ValueError`` - confirm.
        num_classes (int): Number of output classes. When it differs from
            1000 the classification head (``heads`` or ``head``) is replaced
            by a freshly initialised ``nn.Linear`` for fine-tuning.

    Returns:
        nn.Module: Pretrained ViT model

    Raises:
        ImportError: If torchvision cannot be imported.
        ValueError: If ``model_name`` is not found in ``torchvision.models``.
    """
    # Only the import is guarded, so an ImportError raised while building the
    # model is not misreported as "torchvision is missing".
    try:
        from torchvision import models
    except ImportError as err:
        raise ImportError("torchvision is required to load pretrained models") from err

    if not hasattr(models, model_name):
        error_msg = f"Model {model_name} not found in torchvision.models"
        raise ValueError(error_msg)

    # NOTE(review): newer torchvision releases prefer ``weights=...`` over
    # ``pretrained=True``; kept as-is for compatibility with older versions.
    model = getattr(models, model_name)(pretrained=True)

    if num_classes != 1000:
        # Replace the head for fine-tuning; "heads" takes precedence over
        # "head", matching the original lookup order.
        if hasattr(model, "heads"):
            model.heads = nn.Linear(_head_in_features(model.heads), num_classes)
        elif hasattr(model, "head"):
            model.head = nn.Linear(_head_in_features(model.head), num_classes)
    return model
@@ -424,7 +410,7 @@ def count_parameters(model: nn.Module) -> int:
424410 num_classes = 3 , # pizza, steak, sushi
425411 embed_dim = 768 ,
426412 depth = 12 ,
427- num_heads = 12 ,
413+ num_heads = 12
428414 )
429415
430416 print (f"Model created with { count_parameters (model ):,} parameters" )
0 commit comments