Skip to content

Commit 6342487

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 96143f7 commit 6342487

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

computer_vision/vision_tranformer.py

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,11 @@ class PatchEmbedding(nn.Module):
3030
"""
3131

3232
def __init__(
33-
self, img_size: int = 224, patch_size: int = 16,
34-
in_channels: int = 3, embed_dim: int = 768
33+
self,
34+
img_size: int = 224,
35+
patch_size: int = 16,
36+
in_channels: int = 3,
37+
embed_dim: int = 768,
3538
):
3639
super().__init__()
3740
self.img_size = img_size
@@ -42,7 +45,7 @@ def __init__(
4245
in_channels=in_channels,
4346
out_channels=embed_dim,
4447
kernel_size=patch_size,
45-
stride=patch_size
48+
stride=patch_size,
4649
)
4750

4851
def forward(self, x: Tensor) -> Tensor:
@@ -107,7 +110,7 @@ def forward(self, x: Tensor) -> Tensor:
107110
q, k, v = qkv[0], qkv[1], qkv[2] # (B, num_heads, N, head_dim)
108111

109112
# Scaled dot-product attention
110-
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
113+
attn = (q @ k.transpose(-2, -1)) * (self.head_dim**-0.5)
111114
attn = functional.softmax(attn, dim=-1)
112115
attn = self.attn_dropout(attn)
113116

@@ -131,7 +134,9 @@ class MLPBlock(nn.Module):
131134
dropout (float): Dropout rate
132135
"""
133136

134-
def __init__(self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0):
137+
def __init__(
138+
self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0
139+
):
135140
super().__init__()
136141
hidden_dim = int(embed_dim * mlp_ratio)
137142

@@ -170,8 +175,11 @@ class TransformerEncoderBlock(nn.Module):
170175
"""
171176

172177
def __init__(
173-
self, embed_dim: int = 768, num_heads: int = 12,
174-
mlp_ratio: float = 4.0, dropout: float = 0.1
178+
self,
179+
embed_dim: int = 768,
180+
num_heads: int = 12,
181+
mlp_ratio: float = 4.0,
182+
dropout: float = 0.1,
175183
):
176184
super().__init__()
177185

@@ -227,7 +235,7 @@ def __init__(
227235
num_heads: int = 12,
228236
mlp_ratio: float = 4.0,
229237
dropout: float = 0.1,
230-
emb_dropout: float = 0.1
238+
emb_dropout: float = 0.1,
231239
):
232240
super().__init__()
233241

@@ -245,10 +253,12 @@ def __init__(
245253
self.pos_dropout = nn.Dropout(emb_dropout)
246254

247255
# Transformer encoder blocks
248-
self.blocks = nn.ModuleList([
249-
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
250-
for _ in range(depth)
251-
])
256+
self.blocks = nn.ModuleList(
257+
[
258+
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
259+
for _ in range(depth)
260+
]
261+
)
252262

253263
# Layer normalization and classifier
254264
self.norm = nn.LayerNorm(embed_dim)
@@ -325,7 +335,7 @@ def create_vit_model(
325335
num_heads: int = 12,
326336
mlp_ratio: float = 4.0,
327337
dropout: float = 0.1,
328-
emb_dropout: float = 0.1
338+
emb_dropout: float = 0.1,
329339
) -> VisionTransformer:
330340
"""
331341
Factory function to create a Vision Transformer model.
@@ -355,11 +365,13 @@ def create_vit_model(
355365
num_heads=num_heads,
356366
mlp_ratio=mlp_ratio,
357367
dropout=dropout,
358-
emb_dropout=emb_dropout
368+
emb_dropout=emb_dropout,
359369
)
360370

361371

362-
def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: int = 1000) -> nn.Module:
372+
def get_pretrained_vit(
373+
model_name: str = "vit_base_patch16_224", num_classes: int = 1000
374+
) -> nn.Module:
363375
"""
364376
Load a pretrained ViT model from torchvision.
365377
@@ -377,9 +389,9 @@ def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: in
377389
model = getattr(models, model_name)(pretrained=True)
378390
if num_classes != 1000:
379391
# Replace the head for fine-tuning
380-
if hasattr(model, 'heads'):
392+
if hasattr(model, "heads"):
381393
model.heads = nn.Linear(model.heads.in_features, num_classes)
382-
elif hasattr(model, 'head'):
394+
elif hasattr(model, "head"):
383395
model.head = nn.Linear(model.head.in_features, num_classes)
384396
return model
385397
else:
@@ -411,7 +423,7 @@ def count_parameters(model: nn.Module) -> int:
411423
num_classes=3, # pizza, steak, sushi
412424
embed_dim=768,
413425
depth=12,
414-
num_heads=12
426+
num_heads=12,
415427
)
416428

417429
print(f"Model created with {count_parameters(model):,} parameters")

0 commit comments

Comments (0)