Skip to content

Commit 7dde477

Browse files
authored
Update vision_tranformer.py
1 parent 224ec88 commit 7dde477

File tree

1 file changed

+36
-50
lines changed

1 file changed

+36
-50
lines changed

computer_vision/vision_tranformer.py

Lines changed: 36 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""
2-
Vision Transformer (ViT) Implementation
2+
Vision Transformer (ViT) Implementation.
33
44
This module contains a PyTorch implementation of the Vision Transformer (ViT)
5-
architecture based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
5+
architecture based on the paper "An Image is Worth 16x16 Words:
6+
Transformers for Image Recognition at Scale".
67
78
Key Components:
89
- Patch Embedding
@@ -12,17 +13,13 @@
1213
- Vision Transformer Model
1314
"""
1415

15-
import torch
16-
import torch.nn as nn
17-
import torch.nn.functional as F
18-
from torch import Tensor
19-
from typing import Optional, Tuple
20-
import math
16+
from torch import Tensor, nn
17+
import torch.nn.functional as functional
2118

2219

2320
class PatchEmbedding(nn.Module):
2421
"""
25-
Creates patch embeddings from input images as described in Equation 1 of ViT paper.
22+
Creates patch embeddings from input images as described in Equation 1.
2623
2724
Args:
2825
img_size (int): Size of input image (assumed square)
@@ -32,11 +29,8 @@ class PatchEmbedding(nn.Module):
3229
"""
3330

3431
def __init__(
35-
self,
36-
img_size: int = 224,
37-
patch_size: int = 16,
38-
in_channels: int = 3,
39-
embed_dim: int = 768,
32+
self, img_size: int = 224, patch_size: int = 16,
33+
in_channels: int = 3, embed_dim: int = 768
4034
):
4135
super().__init__()
4236
self.img_size = img_size
@@ -47,7 +41,7 @@ def __init__(
4741
in_channels=in_channels,
4842
out_channels=embed_dim,
4943
kernel_size=patch_size,
50-
stride=patch_size,
44+
stride=patch_size
5145
)
5246

5347
def forward(self, x: Tensor) -> Tensor:
@@ -68,7 +62,7 @@ def forward(self, x: Tensor) -> Tensor:
6862

6963
class MultiHeadSelfAttention(nn.Module):
7064
"""
71-
Multi-Head Self Attention (MSA) block as described in Equation 2 of ViT paper.
65+
Multi-Head Self Attention (MSA) block as described in Equation 2.
7266
7367
Args:
7468
embed_dim (int): Dimension of embedding
@@ -101,23 +95,23 @@ def forward(self, x: Tensor) -> Tensor:
10195
Returns:
10296
Tensor: Output tensor of same shape as input
10397
"""
104-
B, N, C = x.shape
98+
batch_size, num_patches, channels = x.shape
10599

106100
# Create Q, K, V
107101
qkv = (
108102
self.qkv(x)
109-
.reshape(B, N, 3, self.num_heads, self.head_dim)
103+
.reshape(batch_size, num_patches, 3, self.num_heads, self.head_dim)
110104
.permute(2, 0, 3, 1, 4)
111105
)
112106
q, k, v = qkv[0], qkv[1], qkv[2] # (B, num_heads, N, head_dim)
113107

114108
# Scaled dot-product attention
115-
attn = (q @ k.transpose(-2, -1)) * (self.head_dim**-0.5) # (B, num_heads, N, N)
116-
attn = F.softmax(attn, dim=-1)
109+
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
110+
attn = functional.softmax(attn, dim=-1)
117111
attn = self.attn_dropout(attn)
118112

119113
# Apply attention to values
120-
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, embed_dim)
114+
x = (attn @ v).transpose(1, 2).reshape(batch_size, num_patches, channels)
121115

122116
# Projection
123117
x = self.proj(x)
@@ -128,17 +122,15 @@ def forward(self, x: Tensor) -> Tensor:
128122

129123
class MLPBlock(nn.Module):
130124
"""
131-
Multilayer Perceptron (MLP) block as described in Equation 3 of ViT paper.
125+
Multilayer Perceptron (MLP) block as described in Equation 3.
132126
133127
Args:
134128
embed_dim (int): Dimension of embedding
135129
mlp_ratio (float): Ratio of MLP hidden dimension to embed_dim
136130
dropout (float): Dropout rate
137131
"""
138132

139-
def __init__(
140-
self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0
141-
):
133+
def __init__(self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0):
142134
super().__init__()
143135
hidden_dim = int(embed_dim * mlp_ratio)
144136

@@ -177,11 +169,8 @@ class TransformerEncoderBlock(nn.Module):
177169
"""
178170

179171
def __init__(
180-
self,
181-
embed_dim: int = 768,
182-
num_heads: int = 12,
183-
mlp_ratio: float = 4.0,
184-
dropout: float = 0.1,
172+
self, embed_dim: int = 768, num_heads: int = 12,
173+
mlp_ratio: float = 4.0, dropout: float = 0.1
185174
):
186175
super().__init__()
187176

@@ -237,7 +226,7 @@ def __init__(
237226
num_heads: int = 12,
238227
mlp_ratio: float = 4.0,
239228
dropout: float = 0.1,
240-
emb_dropout: float = 0.1,
229+
emb_dropout: float = 0.1
241230
):
242231
super().__init__()
243232

@@ -255,12 +244,10 @@ def __init__(
255244
self.pos_dropout = nn.Dropout(emb_dropout)
256245

257246
# Transformer encoder blocks
258-
self.blocks = nn.ModuleList(
259-
[
260-
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
261-
for _ in range(depth)
262-
]
263-
)
247+
self.blocks = nn.ModuleList([
248+
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
249+
for _ in range(depth)
250+
])
264251

265252
# Layer normalization and classifier
266253
self.norm = nn.LayerNorm(embed_dim)
@@ -300,14 +287,14 @@ def forward(self, x: Tensor) -> Tensor:
300287
Returns:
301288
Tensor: Output logits of shape (B, num_classes)
302289
"""
303-
B = x.shape[0]
290+
batch_size = x.shape[0]
304291

305292
# Create patch embeddings
306293
x = self.patch_embed(x) # (B, n_patches, embed_dim)
307294

308295
# Add class token
309-
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, embed_dim)
310-
x = torch.cat((cls_tokens, x), dim=1) # (B, n_patches + 1, embed_dim)
296+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
297+
x = torch.cat((cls_tokens, x), dim=1)
311298

312299
# Add position embedding and apply dropout
313300
x = x + self.pos_embed
@@ -337,7 +324,7 @@ def create_vit_model(
337324
num_heads: int = 12,
338325
mlp_ratio: float = 4.0,
339326
dropout: float = 0.1,
340-
emb_dropout: float = 0.1,
327+
emb_dropout: float = 0.1
341328
) -> VisionTransformer:
342329
"""
343330
Factory function to create a Vision Transformer model.
@@ -367,13 +354,11 @@ def create_vit_model(
367354
num_heads=num_heads,
368355
mlp_ratio=mlp_ratio,
369356
dropout=dropout,
370-
emb_dropout=emb_dropout,
357+
emb_dropout=emb_dropout
371358
)
372359

373360

374-
def get_pretrained_vit(
375-
model_name: str = "vit_base_patch16_224", num_classes: int = 1000
376-
) -> nn.Module:
361+
def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: int = 1000) -> nn.Module:
377362
"""
378363
Load a pretrained ViT model from torchvision.
379364
@@ -385,19 +370,20 @@ def get_pretrained_vit(
385370
nn.Module: Pretrained ViT model
386371
"""
387372
try:
388-
import torchvision.models as models
373+
from torchvision import models
389374

390375
if hasattr(models, model_name):
391376
model = getattr(models, model_name)(pretrained=True)
392377
if num_classes != 1000:
393378
# Replace the head for fine-tuning
394-
if hasattr(model, "heads"):
379+
if hasattr(model, 'heads'):
395380
model.heads = nn.Linear(model.heads.in_features, num_classes)
396-
elif hasattr(model, "head"):
381+
elif hasattr(model, 'head'):
397382
model.head = nn.Linear(model.head.in_features, num_classes)
398383
return model
399384
else:
400-
raise ValueError(f"Model {model_name} not found in torchvision.models")
385+
error_msg = f"Model {model_name} not found in torchvision.models"
386+
raise ValueError(error_msg)
401387

402388
except ImportError:
403389
raise ImportError("torchvision is required to load pretrained models")
@@ -424,7 +410,7 @@ def count_parameters(model: nn.Module) -> int:
424410
num_classes=3, # pizza, steak, sushi
425411
embed_dim=768,
426412
depth=12,
427-
num_heads=12,
413+
num_heads=12
428414
)
429415

430416
print(f"Model created with {count_parameters(model):,} parameters")

0 commit comments

Comments (0)