Skip to content

Commit 62af08b

Browse files
committed
Update Vision Transformer (ViT) demo with full implementation
1 parent 5fa4b88 commit 62af08b

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed
Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,45 @@
1-
def vision_transformer_demo():
2-
print("Vision Transformer running...")
1+
"""
2+
Vision Transformer (ViT)
3+
========================
4+
5+
This module demonstrates how to use a pretrained Vision Transformer (ViT)
6+
for image classification using Hugging Face's Transformers library.
7+
8+
Source:
9+
https://huggingface.co/docs/transformers/model_doc/vit
10+
"""
11+
12+
from transformers import ViTImageProcessor, ViTForImageClassification
13+
from PIL import Image
14+
import requests
15+
import torch
16+
17+
18+
def vision_transformer_demo() -> None:
    """
    Run a pretrained ViT classifier on a sample image and print the result.

    Downloads a demo image, classifies it with ``google/vit-base-patch16-224``
    from Hugging Face, and prints the top-1 ImageNet label.

    Example:
        >>> vision_transformer_demo()  # doctest: +SKIP
        Predicted label: tabby, tabby cat
    """
    # Fetch the sample image as a stream and decode it with PIL.
    sample_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat_sample.jpeg"
    response = requests.get(sample_url, stream=True)
    pil_image = Image.open(response.raw)

    # Pretrained preprocessing pipeline and classification head (ImageNet-1k).
    checkpoint = "google/vit-base-patch16-224"
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    model = ViTForImageClassification.from_pretrained(checkpoint)

    # Resize/normalize the image into a PyTorch tensor batch.
    encoded = processor(images=pil_image, return_tensors="pt")

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        class_logits = model(**encoded).logits

    # Top-1 class index, mapped back to its human-readable label.
    top_class = class_logits.argmax(-1).item()
    label = model.config.id2label[top_class]

    print(f"Predicted label: {label}")
42+
43+
44+
# Script entry point: run the demo when executed directly.
if __name__ == "__main__":
    vision_transformer_demo()

0 commit comments

Comments
 (0)