A PyTorch implementation of knowledge distillation for brain tumor classification, where a lightweight MobileNetV3-Small student model learns from a powerful ResNet-101 teacher model.
This project demonstrates how knowledge distillation can compress a large neural network into a smaller, more efficient model while maintaining high accuracy. The system classifies brain MRI images into four categories:
- Glioma
- Meningioma
- No Tumor
- Pituitary Tumor
- 🧠 Teacher Model: ResNet-101 (pre-trained on ImageNet)
- 📱 Student Model: MobileNetV3-Small (pre-trained on ImageNet)
- 🔥 Knowledge Distillation: temperature-scaled soft targets + hard cross-entropy loss
- 🎨 Visualization: Grad-CAM heatmaps to compare teacher vs. student attention
- 📈 High Performance: the student achieves 98.7% accuracy on the test set
| Model | Accuracy | F1 Score |
|---|---|---|
| Teacher (ResNet-101) | 99.30% | - |
| Student (MobileNetV3-Small) | 98.70% | 0.986 |
| Actual \ Predicted | Glioma | Meningioma | No Tumor | Pituitary |
|---|---|---|---|---|
| Glioma | 288 | 11 | 0 | 1 |
| Meningioma | 1 | 303 | 2 | 0 |
| No Tumor | 0 | 0 | 405 | 0 |
| Pituitary | 0 | 2 | 0 | 298 |
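The reported student accuracy can be recovered directly from the confusion matrix above, since correct predictions sit on the diagonal. A quick sanity check in plain Python:

```python
# Student confusion matrix from the table above (rows = actual, cols = predicted).
cm = [
    [288, 11, 0, 1],   # Glioma
    [1, 303, 2, 0],    # Meningioma
    [0, 0, 405, 0],    # No Tumor
    [0, 2, 0, 298],    # Pituitary
]

total = sum(sum(row) for row in cm)          # 1311 test images
correct = sum(cm[i][i] for i in range(4))    # 1294 on the diagonal
accuracy = correct / total
print(f"{accuracy:.3f}")  # 0.987, matching the reported 98.7%
```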
```bash
pip install torch torchvision
pip install scikit-learn matplotlib seaborn tqdm grad-cam
```

The dataset is expected in the following layout:

```
brain-tumor-mri-dataset/
├── Training/
│   ├── glioma/
│   ├── meningioma/
│   ├── notumor/
│   └── pituitary/
└── Testing/
    ├── glioma/
    ├── meningioma/
    ├── notumor/
    └── pituitary/
```
Train the teacher, then distill into the student, then evaluate:

```python
train_teacher(epochs=10)
train_student(epochs=10)
evaluate(student, test_loader)
```

The distillation loss combines hard and soft components:

```python
loss = alpha * hard_loss + (1 - alpha) * soft_loss
```

- Hard Loss: cross-entropy with the true labels
- Soft Loss: KL divergence between student and teacher predictions
- Temperature (T=4): softens the probability distributions
- Alpha (α=0.5): balances the hard and soft losses
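The combined loss can be sketched as follows (a minimal implementation of the formula above; the `T**2` scaling is the standard correction from Hinton et al. to keep soft-target gradients comparable across temperatures):

```python
import torch
import torch.nn.functional as F


def kd_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Combined distillation loss: alpha * hard CE + (1 - alpha) * soft KL."""
    # Hard loss: standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    # Soft loss: KL divergence between temperature-softened distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * hard + (1 - alpha) * soft
```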
- Random horizontal flip
- Random rotation (±15°)
- Resize to 224×224
- ImageNet normalization
The project includes visualization tools to better understand model performance:
- 📊 Prediction Comparison: bar charts of teacher vs. student confidence
- 📈 Confusion Matrix: heatmap of classification performance
- 🔥 Grad-CAM: visualizes attention regions in MRI scans
Example usage:
```python
visualize_classification_kd(teacher, student, val_dataset, idx=1000, device='cuda')
```

Teacher (ResNet-101):
- Depth: 101 layers
- Parameters: ~44M
- Final layer modified for 4-class classification

Student (MobileNetV3-Small):
- Lightweight architecture
- Parameters: ~2.5M (94% reduction)
- Final layer modified for 4-class classification
```
Batch Size: 32
Learning Rate: 1e-4
Optimizer: Adam
Teacher Epochs: 10
Student Epochs: 10
Validation Split: 20%
```

- `train_teacher()` – train the ResNet-101 teacher model
- `train_student()` – train the student with knowledge distillation
- `kd_loss()` – compute the combined distillation loss
- `evaluate()` – calculate model accuracy
- `visualize_classification_kd()` – visualize predictions with Grad-CAM
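The distillation loop behind `train_student()` might look like this sketch (model, loader, and parameter names are illustrative; the inlined loss follows the α-weighted formula described earlier):

```python
import torch
import torch.nn.functional as F


def train_student(student, teacher, loader, epochs=10, lr=1e-4,
                  T=4.0, alpha=0.5, device="cpu"):
    """Sketch of a KD training loop using the hyperparameters listed above."""
    teacher.to(device).eval()      # the teacher stays frozen
    student.to(device).train()
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                t_logits = teacher(images)
            s_logits = student(images)
            # Combined loss: alpha * hard CE + (1 - alpha) * soft KL term.
            hard = F.cross_entropy(s_logits, labels)
            soft = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                            F.softmax(t_logits / T, dim=1),
                            reduction="batchmean") * (T * T)
            loss = alpha * hard + (1 - alpha) * soft
            opt.zero_grad()
            loss.backward()
            opt.step()
    return student
```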
If you use this code in your research, please cite:
```bibtex
@article{hinton2015distilling,
  title={Distilling the knowledge in a neural network},
  author={Hinton, Geoffrey and Vinyals, Oriol and Dean, Jeff},
  journal={arXiv preprint arXiv:1503.02531},
  year={2015}
}
```

This project is available for educational and research purposes under the MIT License.
- 𧬠Brain Tumor MRI Dataset from Kaggle
- π§ PyTorch and torchvision teams
- π₯ Grad-CAM implementation from pytorch-grad-cam
Contributions, issues, and feature requests are welcome! Feel free to check the issues page.