Skip to content

Commit 141f3ca

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7dde477 commit 141f3ca

File tree

1 file changed

+30
-18
lines changed

1 file changed

+30
-18
lines changed

computer_vision/vision_tranformer.py

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -29,8 +29,11 @@ class PatchEmbedding(nn.Module):
2929
"""
3030

3131
def __init__(
32-
self, img_size: int = 224, patch_size: int = 16,
33-
in_channels: int = 3, embed_dim: int = 768
32+
self,
33+
img_size: int = 224,
34+
patch_size: int = 16,
35+
in_channels: int = 3,
36+
embed_dim: int = 768,
3437
):
3538
super().__init__()
3639
self.img_size = img_size
@@ -41,7 +44,7 @@ def __init__(
4144
in_channels=in_channels,
4245
out_channels=embed_dim,
4346
kernel_size=patch_size,
44-
stride=patch_size
47+
stride=patch_size,
4548
)
4649

4750
def forward(self, x: Tensor) -> Tensor:
@@ -106,7 +109,7 @@ def forward(self, x: Tensor) -> Tensor:
106109
q, k, v = qkv[0], qkv[1], qkv[2] # (B, num_heads, N, head_dim)
107110

108111
# Scaled dot-product attention
109-
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
112+
attn = (q @ k.transpose(-2, -1)) * (self.head_dim**-0.5)
110113
attn = functional.softmax(attn, dim=-1)
111114
attn = self.attn_dropout(attn)
112115

@@ -130,7 +133,9 @@ class MLPBlock(nn.Module):
130133
dropout (float): Dropout rate
131134
"""
132135

133-
def __init__(self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0):
136+
def __init__(
137+
self, embed_dim: int = 768, mlp_ratio: float = 4.0, dropout: float = 0.0
138+
):
134139
super().__init__()
135140
hidden_dim = int(embed_dim * mlp_ratio)
136141

@@ -169,8 +174,11 @@ class TransformerEncoderBlock(nn.Module):
169174
"""
170175

171176
def __init__(
172-
self, embed_dim: int = 768, num_heads: int = 12,
173-
mlp_ratio: float = 4.0, dropout: float = 0.1
177+
self,
178+
embed_dim: int = 768,
179+
num_heads: int = 12,
180+
mlp_ratio: float = 4.0,
181+
dropout: float = 0.1,
174182
):
175183
super().__init__()
176184

@@ -226,7 +234,7 @@ def __init__(
226234
num_heads: int = 12,
227235
mlp_ratio: float = 4.0,
228236
dropout: float = 0.1,
229-
emb_dropout: float = 0.1
237+
emb_dropout: float = 0.1,
230238
):
231239
super().__init__()
232240

@@ -244,10 +252,12 @@ def __init__(
244252
self.pos_dropout = nn.Dropout(emb_dropout)
245253

246254
# Transformer encoder blocks
247-
self.blocks = nn.ModuleList([
248-
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
249-
for _ in range(depth)
250-
])
255+
self.blocks = nn.ModuleList(
256+
[
257+
TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
258+
for _ in range(depth)
259+
]
260+
)
251261

252262
# Layer normalization and classifier
253263
self.norm = nn.LayerNorm(embed_dim)
@@ -324,7 +334,7 @@ def create_vit_model(
324334
num_heads: int = 12,
325335
mlp_ratio: float = 4.0,
326336
dropout: float = 0.1,
327-
emb_dropout: float = 0.1
337+
emb_dropout: float = 0.1,
328338
) -> VisionTransformer:
329339
"""
330340
Factory function to create a Vision Transformer model.
@@ -354,11 +364,13 @@ def create_vit_model(
354364
num_heads=num_heads,
355365
mlp_ratio=mlp_ratio,
356366
dropout=dropout,
357-
emb_dropout=emb_dropout
367+
emb_dropout=emb_dropout,
358368
)
359369

360370

361-
def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: int = 1000) -> nn.Module:
371+
def get_pretrained_vit(
372+
model_name: str = "vit_base_patch16_224", num_classes: int = 1000
373+
) -> nn.Module:
362374
"""
363375
Load a pretrained ViT model from torchvision.
364376
@@ -376,9 +388,9 @@ def get_pretrained_vit(model_name: str = "vit_base_patch16_224", num_classes: in
376388
model = getattr(models, model_name)(pretrained=True)
377389
if num_classes != 1000:
378390
# Replace the head for fine-tuning
379-
if hasattr(model, 'heads'):
391+
if hasattr(model, "heads"):
380392
model.heads = nn.Linear(model.heads.in_features, num_classes)
381-
elif hasattr(model, 'head'):
393+
elif hasattr(model, "head"):
382394
model.head = nn.Linear(model.head.in_features, num_classes)
383395
return model
384396
else:
@@ -410,7 +422,7 @@ def count_parameters(model: nn.Module) -> int:
410422
num_classes=3, # pizza, steak, sushi
411423
embed_dim=768,
412424
depth=12,
413-
num_heads=12
425+
num_heads=12,
414426
)
415427

416428
print(f"Model created with {count_parameters(model):,} parameters")

0 commit comments

Comments (0)