1+ """
2+ Vision Transformer (ViT) Implementation
3+
4+ This module contains a PyTorch implementation of the Vision Transformer (ViT)
5+ architecture based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
6+
7+ Key Components:
8+ - Patch Embedding
9+ - Multi-Head Self Attention
10+ - MLP Block
11+ - Transformer Encoder
12+ - Vision Transformer Model
13+ """
14+
15+ import torch
16+ import torch .nn as nn
17+ import torch .nn .functional as F
18+ from torch import Tensor
19+ from typing import Optional , Tuple
20+ import math
21+
22+
class PatchEmbedding(nn.Module):
    """
    Creates patch embeddings from input images as described in Equation 1 of ViT paper.

    A Conv2d with kernel_size == stride == patch_size is equivalent to
    splitting the image into non-overlapping patches and applying a shared
    linear projection to each one.

    Args:
        img_size (int): Size of input image (assumed square)
        patch_size (int): Size of each patch (assumed square)
        in_channels (int): Number of input channels
        embed_dim (int): Dimension of embedding

    Raises:
        ValueError: If img_size is not divisible by patch_size.
    """

    def __init__(self, img_size: int = 224, patch_size: int = 16, in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        if img_size % patch_size != 0:
            # The strided conv would silently drop trailing pixels, leaving
            # n_patches inconsistent with the actual conv output -> fail fast.
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # Patchify + linear projection in one strided convolution.
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for patch embedding.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W)

        Returns:
            Tensor: Patch embeddings of shape (B, n_patches, embed_dim)
        """
        x = self.proj(x)       # (B, embed_dim, H//patch_size, W//patch_size)
        x = x.flatten(2)       # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x
61+
62+
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self Attention (MSA) block as described in Equation 2 of ViT paper.

    Args:
        embed_dim (int): Dimension of embedding
        num_heads (int): Number of attention heads
        dropout (float): Dropout rate applied to attention weights and output

    Raises:
        ValueError: If embed_dim is not divisible by num_heads.
    """

    def __init__(self, embed_dim: int = 768, num_heads: int = 12, dropout: float = 0.0):
        super().__init__()
        if embed_dim % num_heads != 0:
            # `assert` is stripped under `python -O`; raise explicitly so the
            # check always runs.
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Hoist the loop-invariant 1/sqrt(d_k) attention scale out of forward().
        self.scale = self.head_dim ** -0.5

        # One fused projection produces Q, K and V together.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for multi-head self attention.

        Args:
            x (Tensor): Input tensor of shape (B, n_patches, embed_dim)

        Returns:
            Tensor: Output tensor of same shape as input
        """
        B, N, C = x.shape

        # Project to Q, K, V and split heads:
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, N, head_dim)

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # Weighted sum of values, then merge heads back to embed_dim.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, embed_dim)

        # Output projection
        x = self.proj(x)
        x = self.proj_dropout(x)

        return x
115+
116+
class MLPBlock(nn.Module):
    """
    Multilayer Perceptron (MLP) block as described in Equation 3 of ViT paper.

    Two linear projections with a GELU non-linearity in between; dropout is
    applied after the activation and again after the second projection.

    Args:
        embed_dim (int): Dimension of embedding
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
        dropout (float): Dropout rate
    """

    def __init__(self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for MLP block.

        Args:
            x (Tensor): Input tensor of shape (B, n_patches, embed_dim)

        Returns:
            Tensor: Output tensor of same shape as input
        """
        # expand -> GELU -> dropout, then project back -> dropout
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
152+
153+
class TransformerEncoderBlock(nn.Module):
    """
    Transformer Encoder Block combining MSA and MLP with residual connections.

    Pre-norm arrangement: LayerNorm is applied before each sub-layer and the
    sub-layer's output is added back onto its input.

    Args:
        embed_dim (int): Dimension of embedding
        num_heads (int): Number of attention heads
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
        dropout (float): Dropout rate
    """

    def __init__(self, embed_dim: int = 768, num_heads: int = 12, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLPBlock(embed_dim, mlp_ratio, dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for transformer encoder block.

        Args:
            x (Tensor): Input tensor of shape (B, n_patches, embed_dim)

        Returns:
            Tensor: Output tensor of same shape as input
        """
        # Attention sub-layer (pre-norm) plus skip connection.
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out

        # Feed-forward sub-layer (pre-norm) plus skip connection.
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
190+
191+
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) model.

    Pipeline: patchify -> prepend [CLS] token -> add learned position
    embedding -> stack of pre-norm transformer encoder blocks -> LayerNorm
    -> linear classifier on the [CLS] token.

    Args:
        img_size (int): Input image size
        patch_size (int): Patch size
        in_channels (int): Number of input channels
        num_classes (int): Number of output classes
        embed_dim (int): Embedding dimension
        depth (int): Number of transformer blocks
        num_heads (int): Number of attention heads
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
        dropout (float): Dropout rate
        emb_dropout (float): Embedding dropout rate
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        emb_dropout: float = 0.1
    ):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        # Learnable [CLS] token and position embedding (n_patches + 1 slots:
        # one per patch plus one for the class token). Both are overwritten
        # with truncated-normal values in _init_weights().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(emb_dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Final layer normalization and classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for the ViT model."""
        # The patch-embedding conv acts as a linear projection -> Xavier init.
        nn.init.xavier_uniform_(self.patch_embed.proj.weight)
        if self.patch_embed.proj.bias is not None:
            nn.init.zeros_(self.patch_embed.proj.bias)

        # Class token and position embedding: truncated normal, std 0.02.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Re-initialize every nn.Linear in the network (qkv/proj/mlp/head).
        self.apply(self._init_linear_weights)

    def _init_linear_weights(self, module):
        """Initialize weights for linear layers (trunc-normal weight, zero bias)."""
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for Vision Transformer.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W) with
                H == W == img_size

        Raises:
            ValueError: If the input is not 4-D or its spatial size does not
                match img_size (the fixed-size pos_embed cannot be applied
                to any other resolution).

        Returns:
            Tensor: Output logits of shape (B, num_classes)
        """
        # Fail fast with a clear message instead of a cryptic broadcast
        # error when pos_embed is added below.
        if x.dim() != 4 or x.shape[-2:] != (self.img_size, self.img_size):
            raise ValueError(
                f"Expected input of shape (B, C, {self.img_size}, {self.img_size}), "
                f"got {tuple(x.shape)}"
            )

        B = x.shape[0]

        # Create patch embeddings
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Prepend the class token to every sequence in the batch
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, n_patches + 1, embed_dim)

        # Add position embedding and apply dropout
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization; classify from the class-token representation
        x = self.norm(x)
        cls_token_final = x[:, 0]

        return self.head(cls_token_final)
306+
307+
def create_vit_model(
    img_size: int = 224,
    patch_size: int = 16,
    in_channels: int = 3,
    num_classes: int = 1000,
    embed_dim: int = 768,
    depth: int = 12,
    num_heads: int = 12,
    mlp_ratio: float = 4.0,
    dropout: float = 0.1,
    emb_dropout: float = 0.1
) -> VisionTransformer:
    """
    Factory function to create a Vision Transformer model.

    Args:
        img_size (int): Input image size
        patch_size (int): Patch size
        in_channels (int): Number of input channels
        num_classes (int): Number of output classes
        embed_dim (int): Embedding dimension
        depth (int): Number of transformer blocks
        num_heads (int): Number of attention heads
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
        dropout (float): Dropout rate
        emb_dropout (float): Embedding dropout rate

    Returns:
        VisionTransformer: Configured ViT model
    """
    # Collect the configuration once, then expand it into the constructor.
    config = dict(
        img_size=img_size,
        patch_size=patch_size,
        in_channels=in_channels,
        num_classes=num_classes,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        emb_dropout=emb_dropout,
    )
    return VisionTransformer(**config)
350+
351+
352+
353+
354+
def count_parameters(model: nn.Module) -> int:
    """
    Count the number of trainable parameters in a model.

    Args:
        model (nn.Module): PyTorch model

    Returns:
        int: Number of trainable parameters
    """
    total = 0
    for param in model.parameters():
        # Only parameters that receive gradients count as trainable.
        if param.requires_grad:
            total += param.numel()
    return total
366+
367+
if __name__ == "__main__":
    # Demo: build a ViT-Base sized model for a 3-class problem and push a
    # dummy batch through it as a smoke test.
    vit = create_vit_model(
        img_size=224,
        patch_size=16,
        num_classes=3,  # pizza, steak, sushi
        embed_dim=768,
        depth=12,
        num_heads=12
    )
    print(f"Model created with {count_parameters(vit):,} parameters")

    # Random batch of two RGB images at the expected resolution.
    images = torch.randn(2, 3, 224, 224)
    logits = vit(images)
    print(f"Input shape: {images.shape}")
    print(f"Output shape: {logits.shape}")