
🧠 Brain Tumor MRI Classification using Knowledge Distillation

A PyTorch implementation of knowledge distillation for brain tumor classification, where a lightweight MobileNetV3-Small student model learns from a powerful ResNet-101 teacher model.


πŸ“‹ Overview

This project demonstrates how knowledge distillation can compress a large neural network into a smaller, more efficient model while maintaining high accuracy. The system classifies brain MRI images into four categories:

  • Glioma
  • Meningioma
  • No Tumor
  • Pituitary Tumor

✨ Key Features

  • 🧠 Teacher Model: ResNet-101 (pre-trained on ImageNet)
  • πŸ“± Student Model: MobileNetV3-Small (pre-trained on ImageNet)
  • πŸ”₯ Knowledge Distillation: Temperature-scaled soft targets + hard loss
  • 🎨 Visualization: Grad-CAM heatmaps to compare teacher vs student attention
  • πŸ“ˆ High Performance: Student achieves 98.7% accuracy on the test set

🎯 Results

| Model                       | Accuracy | F1 Score |
|-----------------------------|----------|----------|
| Teacher (ResNet-101)        | 99.30%   | -        |
| Student (MobileNetV3-Small) | 98.70%   | 0.986    |

πŸ“‰ Confusion Matrix

|            | Glioma | Meningioma | No Tumor | Pituitary |
|------------|-------:|-----------:|---------:|----------:|
| Glioma     |    288 |         11 |        0 |         1 |
| Meningioma |      1 |        303 |        2 |         0 |
| No Tumor   |      0 |          0 |      405 |         0 |
| Pituitary  |      0 |          2 |        0 |       298 |

πŸš€ Getting Started

πŸ“¦ Prerequisites

pip install torch torchvision
pip install scikit-learn matplotlib seaborn tqdm grad-cam

πŸ“ Dataset Structure

brain-tumor-mri-dataset/
β”œβ”€β”€ Training/
β”‚   β”œβ”€β”€ glioma/
β”‚   β”œβ”€β”€ meningioma/
β”‚   β”œβ”€β”€ notumor/
β”‚   └── pituitary/
└── Testing/
    β”œβ”€β”€ glioma/
    β”œβ”€β”€ meningioma/
    β”œβ”€β”€ notumor/
    └── pituitary/

πŸ—οΈ Training

The entry points below are Python functions, called from the project's script or notebook (not shell commands):

Train the teacher model:

train_teacher(epochs=10)

Train the student model with knowledge distillation:

train_student(epochs=10)

Evaluate the student model:

evaluate(student, test_loader)
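A single distillation update, as a sketch of what `train_student()` does per batch. The `distill_step` helper is illustrative; the `T*T` scaling on the soft loss follows Hinton et al. and may differ from the repo's exact implementation:

```python
import torch
import torch.nn.functional as F

def distill_step(student, teacher, x, y, optimizer, T=4.0, alpha=0.5):
    """One knowledge-distillation update: the student fits both the
    ground-truth labels and the teacher's temperature-softened outputs."""
    teacher.eval()
    with torch.no_grad():
        t_logits = teacher(x)               # frozen teacher predictions
    s_logits = student(x)
    hard = F.cross_entropy(s_logits, y)     # loss vs. true labels
    soft = F.kl_div(                        # loss vs. soft teacher targets
        F.log_softmax(s_logits / T, dim=1),
        F.softmax(t_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)                             # standard T^2 gradient scaling
    loss = alpha * hard + (1 - alpha) * soft
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```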

πŸ”¬ Methodology

πŸ“˜ Knowledge Distillation Loss

The distillation loss combines hard and soft components:

loss = α * hard_loss + (1 - α) * soft_loss

  • Hard Loss: Cross-entropy with true labels
  • Soft Loss: KL divergence between student and teacher predictions
  • Temperature (T=4): Softens probability distributions
  • Alpha (α=0.5): Balances hard and soft losses
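A quick illustration of the temperature effect, using made-up logits (the values are arbitrary, not from the model): at T=1 the softmax is nearly one-hot, while at T=4 the relative probabilities of the wrong classes become visible for the student to learn from.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([8.0, 2.0, 1.0, 0.5])   # hypothetical teacher logits

sharp = F.softmax(logits, dim=0)        # T = 1: nearly one-hot
soft = F.softmax(logits / 4.0, dim=0)   # T = 4: softened distribution

print(sharp)   # ≈ [0.996, 0.002, 0.001, 0.001]
print(soft)    # ≈ [0.645, 0.144, 0.112, 0.099]
```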

πŸ§ͺ Data Augmentation

  • Random horizontal flip
  • Random rotation (±15°)
  • Resize to 224×224
  • ImageNet normalization

πŸ“Š Visualizations

The project includes visualization tools to better understand model performance:

  • πŸ“ˆ Prediction Comparison: Bar charts for teacher vs student confidence
  • πŸ“‰ Confusion Matrix: Heatmap of classification performance
  • πŸ”₯ Grad-CAM: Visualizes attention regions in MRI scans

Example usage:

visualize_classification_kd(teacher, student, val_dataset, idx=1000, device='cuda')

πŸ—οΈ Architecture Details

🧠 Teacher Network (ResNet-101)

  • Depth: 101 layers
  • Parameters: ~44M
  • Final layer modified for 4-class classification

πŸ“± Student Network (MobileNetV3-Small)

  • Lightweight architecture
  • Parameters: ~2.5M (94% reduction)
  • Final layer modified for 4-class classification

πŸ“ˆ Training Configuration

Batch Size: 32
Learning Rate: 1e-4
Optimizer: Adam
Teacher Epochs: 10
Student Epochs: 10
Validation Split: 20%
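The 20% validation split can be reproduced with `torch.utils.data.random_split`; the `split_train_val` helper and the fixed seed are illustrative, not necessarily the repo's exact code:

```python
import torch
from torch.utils.data import random_split

def split_train_val(full_train, val_frac=0.2, seed=42):
    """Carve a held-out validation set off the training dataset (80/20).

    The fixed generator seed makes the split reproducible across runs.
    """
    n_val = int(len(full_train) * val_frac)
    n_train = len(full_train) - n_val
    return random_split(full_train, [n_train, n_val],
                        generator=torch.Generator().manual_seed(seed))
```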

πŸ” Key Functions

  • train_teacher() – Train the ResNet-101 teacher model
  • train_student() – Train the student using KD
  • kd_loss() – Compute combined distillation loss
  • evaluate() – Calculate model accuracy
  • visualize_classification_kd() – Visualize predictions with Grad-CAM

πŸ“ Citation

If you use this code in your research, please cite:

@article{hinton2015distilling,
  title={Distilling the knowledge in a neural network},
  author={Hinton, Geoffrey and Vinyals, Oriol and Dean, Jeff},
  journal={arXiv preprint arXiv:1503.02531},
  year={2015}
}

πŸ“„ License

This project is available for educational and research purposes under the MIT License.


πŸ™ Acknowledgments

  • 🧬 Brain Tumor MRI Dataset from Kaggle
  • 🧠 PyTorch and torchvision teams
  • πŸ”₯ Grad-CAM implementation from pytorch-grad-cam

🀝 Contributing

Contributions, issues, and feature requests are welcome! Feel free to check the issues page.
