Skip to content

Commit 551cb8c

Browse files
committed
updated the error
1 parent 1a9b0a0 commit 551cb8c

File tree

1 file changed

+385
-0
lines changed

1 file changed

+385
-0
lines changed
Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,385 @@
1+
"""
2+
Vision Transformer (ViT) Implementation
3+
4+
This module contains a PyTorch implementation of the Vision Transformer (ViT)
5+
architecture based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
6+
7+
Key Components:
8+
- Patch Embedding
9+
- Multi-Head Self Attention
10+
- MLP Block
11+
- Transformer Encoder
12+
- Vision Transformer Model
13+
"""
14+
15+
import torch
16+
import torch.nn as nn
17+
import torch.nn.functional as F
18+
from torch import Tensor
19+
from typing import Optional, Tuple
20+
import math
21+
22+
23+
class PatchEmbedding(nn.Module):
    """
    Turn an image into a sequence of patch embeddings (Equation 1 of the ViT paper).

    The image is cut into non-overlapping square patches by a convolution whose
    kernel size and stride both equal ``patch_size``; each output position of
    that convolution is the embedding of one patch.

    Args:
        img_size (int): Side length of the (square) input image.
        patch_size (int): Side length of each (square) patch.
        in_channels (int): Number of input channels.
        embed_dim (int): Size of each patch embedding vector.
    """

    def __init__(self, img_size: int = 224, patch_size: int = 16, in_channels: int = 3, embed_dim: int = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        # Patches along one side; a remainder of img_size / patch_size is
        # discarded by the floor division.
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2

        # Kernel == stride, so every patch is projected exactly once.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: Tensor) -> Tensor:
        """
        Project an image batch to patch embeddings.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            Tensor: Patch embeddings of shape (B, n_patches, embed_dim).
        """
        # conv:      (B, embed_dim, H // patch_size, W // patch_size)
        # flatten:   (B, embed_dim, n_patches)
        # transpose: (B, n_patches, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
61+
62+
63+
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self Attention (MSA) block as described in Equation 2 of the ViT paper.

    A single linear layer produces queries, keys and values for all heads at
    once; scaled dot-product attention is computed per head, the heads are
    concatenated, and a final linear projection is applied.

    Args:
        embed_dim (int): Dimension of embedding.
        num_heads (int): Number of attention heads; must be positive and divide ``embed_dim``.
        dropout (float): Dropout rate applied to the attention weights and to the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not evenly divide ``embed_dim``.
    """

    def __init__(self, embed_dim: int = 768, num_heads: int = 12, dropout: float = 0.0):
        super().__init__()

        # Validate with an explicit raise: an ``assert`` is removed when Python
        # runs with -O, and num_heads == 0 would otherwise fail with an
        # unrelated ZeroDivisionError on the division below.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got embed_dim={embed_dim}, num_heads={num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for multi-head self attention.

        Args:
            x (Tensor): Input tensor of shape (B, N, embed_dim), N being the number of tokens.

        Returns:
            Tensor: Output tensor of the same shape as the input.
        """
        B, N, C = x.shape

        # One matmul yields Q, K and V; split the last axis into
        # (3, num_heads, head_dim) and move the Q/K/V axis to the front.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, N, head_dim)

        # Scaled dot-product attention; the 1/sqrt(head_dim) factor keeps the
        # logits' magnitude independent of the head size.
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)  # (B, num_heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # Weighted sum of values, then merge the heads back into one axis.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (B, N, embed_dim)

        # Output projection
        x = self.proj(x)
        x = self.proj_dropout(x)

        return x
115+
116+
117+
class MLPBlock(nn.Module):
    """
    Position-wise feed-forward network (Equation 3 of the ViT paper).

    Two linear layers with a GELU in between. The same dropout module is
    applied after the activation and again after the second linear layer.

    Args:
        embed_dim (int): Dimension of embedding (input and output width).
        mlp_ratio (float): Hidden width as a multiple of ``embed_dim``; the
            product is truncated to an integer.
        dropout (float): Dropout rate.
    """

    def __init__(self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_width = int(embed_dim * mlp_ratio)

        self.fc1 = nn.Linear(embed_dim, hidden_width)  # expand
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_width, embed_dim)  # contract
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply expand -> GELU -> dropout -> contract -> dropout.

        Args:
            x (Tensor): Input tensor whose last dimension is ``embed_dim``.

        Returns:
            Tensor: Output tensor of the same shape as the input.
        """
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
152+
153+
154+
class TransformerEncoderBlock(nn.Module):
    """
    One Transformer encoder layer: self-attention followed by an MLP.

    Each sub-layer receives a LayerNorm-ed copy of its input, and its output
    is added back to the un-normalised input (residual connection).

    Args:
        embed_dim (int): Dimension of embedding.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim.
        dropout (float): Dropout rate passed to both the attention and MLP sub-layers.
    """

    def __init__(self, embed_dim: int = 768, num_heads: int = 12, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()

        # Attention sub-layer and its normalisation
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)

        # Feed-forward sub-layer and its normalisation
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLPBlock(embed_dim, mlp_ratio, dropout)

    def forward(self, x: Tensor) -> Tensor:
        """
        Run the attention and MLP sub-layers, each wrapped in a residual connection.

        Args:
            x (Tensor): Input tensor of shape (B, n_tokens, embed_dim).

        Returns:
            Tensor: Output tensor of the same shape as the input.
        """
        attended = self.attn(self.norm1(x))
        x = x + attended

        transformed = self.mlp(self.norm2(x))
        return x + transformed
190+
191+
192+
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) model.

    Pipeline: patch embedding -> prepend a learnable class token -> add a
    learnable position embedding -> ``depth`` encoder blocks -> LayerNorm ->
    linear classifier applied to the class token.

    Args:
        img_size (int): Input image size
        patch_size (int): Patch size
        in_channels (int): Number of input channels
        num_classes (int): Number of output classes
        embed_dim (int): Embedding dimension
        depth (int): Number of transformer blocks
        num_heads (int): Number of attention heads
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
        dropout (float): Dropout rate
        emb_dropout (float): Embedding dropout rate
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        num_classes: int = 1000,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        emb_dropout: float = 0.1
    ):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches

        # Class token and position embedding (one extra position for the class token).
        # Both start as zeros and are filled in by _init_weights().
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(emb_dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        # Layer normalization and classifier
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize weights for the ViT model."""
        # Initialize patch embedding like a linear layer: flatten the conv
        # kernel from (embed_dim, C, P, P) to (embed_dim, C*P*P) first.
        # Applied to the 4-D weight directly, Xavier would multiply fan_out by
        # the kernel area and produce a much smaller bound than a linear layer
        # of the same size would get. The detached view shares storage with
        # the parameter, so the in-place init updates the real weight.
        patch_weight = self.patch_embed.proj.weight
        nn.init.xavier_uniform_(patch_weight.detach().view(patch_weight.shape[0], -1))
        if self.patch_embed.proj.bias is not None:
            nn.init.zeros_(self.patch_embed.proj.bias)

        # Initialize class token and position embedding
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Initialize linear layers (visits every sub-module, including self)
        self.apply(self._init_linear_weights)

    def _init_linear_weights(self, module: nn.Module) -> None:
        """Initialize an ``nn.Linear``: truncated-normal weight, zero bias. Other modules are left untouched."""
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass for Vision Transformer.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W). The number of
                patches it produces must equal the count the position
                embedding was sized for, otherwise the addition below fails.

        Returns:
            Tensor: Output logits of shape (B, num_classes)
        """
        B = x.shape[0]

        # Create patch embeddings
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        # Prepend the class token; expand() broadcasts without copying
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, n_patches + 1, embed_dim)

        # Add position embedding and apply dropout
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Apply final normalization and get class token output
        x = self.norm(x)
        cls_token_final = x[:, 0]  # Use class token for classification

        # Classifier
        x = self.head(cls_token_final)

        return x
306+
307+
308+
def create_vit_model(
    img_size: int = 224,
    patch_size: int = 16,
    in_channels: int = 3,
    num_classes: int = 1000,
    embed_dim: int = 768,
    depth: int = 12,
    num_heads: int = 12,
    mlp_ratio: float = 4.0,
    dropout: float = 0.1,
    emb_dropout: float = 0.1
) -> VisionTransformer:
    """
    Build a Vision Transformer from the given hyper-parameters.

    Every argument is forwarded unchanged, by keyword, to ``VisionTransformer``.

    Args:
        img_size (int): Input image size
        patch_size (int): Patch size
        in_channels (int): Number of input channels
        num_classes (int): Number of output classes
        embed_dim (int): Embedding dimension
        depth (int): Number of transformer blocks
        num_heads (int): Number of attention heads
        mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
        dropout (float): Dropout rate
        emb_dropout (float): Embedding dropout rate

    Returns:
        VisionTransformer: Configured ViT model
    """
    config = {
        "img_size": img_size,
        "patch_size": patch_size,
        "in_channels": in_channels,
        "num_classes": num_classes,
        "embed_dim": embed_dim,
        "depth": depth,
        "num_heads": num_heads,
        "mlp_ratio": mlp_ratio,
        "dropout": dropout,
        "emb_dropout": emb_dropout,
    }
    return VisionTransformer(**config)
350+
351+
352+
353+
354+
355+
def count_parameters(model: nn.Module) -> int:
    """
    Return how many scalar parameters of ``model`` are trainable.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model (nn.Module): PyTorch model

    Returns:
        int: Number of trainable parameters
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
366+
367+
368+
if __name__ == "__main__":
    # Demo: build a ViT with a 3-way head (pizza, steak, sushi) and report its size
    model = create_vit_model(
        num_classes=3,
        img_size=224,
        patch_size=16,
        embed_dim=768,
        num_heads=12,
        depth=12,
    )
    n_params = count_parameters(model)
    print(f"Model created with {n_params:,} parameters")

    # Smoke-test the forward pass on a random batch of two RGB images
    x = torch.randn(2, 3, 224, 224)
    out = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape}")

0 commit comments

Comments
 (0)