Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 79 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Implementation of [GradCAM](https://arxiv.org/pdf/1610.02391): Visualize the wei
|---|---|
| `Normal` | Standard CNN-based models (ResNet, etc.) |
| `ViT` | Vision Transformer models |
| `DeiT` | vision transformer structured model with one extra distillation token embedding row |
| `SwinT` | Swin Transformer models |

## Installation
Expand All @@ -29,13 +30,12 @@ pip install torch torchvision timm matplotlib pillow numpy
```python
GradCAM(
model, # torch.nn.Module — the model to visualize
layer_name, # str — full name of the name intended for GradCAM
img_path=None, # str — path to the input image
img_value=None, # tensor — pre-loaded image tensor (alternative to img_path)
layer_idx=2, # int — index of the layer to visualize before
input_shape=(224, 224), # tuple — image resize shape (used when no transform provided)
model_type='Normal', # str — 'Normal', 'ViT', or 'SwinT'
transform=None, # transforms.Compose — custom preprocessing (recommended for ViT/SwinT)
auto_find_classfier=False, # bool — auto-detect classifier head by name containing 'fier'
verbose=False, # bool — print debug info about shapes and predictions
)
```
Expand All @@ -56,10 +56,20 @@ Displays a 2×2 figure: original image, overlapped colormap, raw heatmap, and pr
---

## Examples
### Get access to the layer names:

```python
from grad_cam_code.grad_cam import *
model = create_model('timm/resnet18.a1_in1k', pretrained=True)
model.eval()
print_layername(model)
```

### ResNet (CNN)

> **Note:** ResNet ends with an `AdaptiveAvgPool` before the classifier — use `layer_idx=-2` to skip it, or `-1` will raise a dimension error.
> **Note:** ResNet ends with an `AdaptiveAvgPool` before the classifier — use `layer4.1.conv2` instead of last layer of the backbone, or it will raise a dimension error.

> **Note:** It is important to implement `model.eval()` to make GradCAM success. This action will make the prediction stable and disable some training actions (e.g.: BatchNorm, Dropout).

```python
from grad_cam_code.grad_cam import *
Expand All @@ -69,14 +79,14 @@ model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'

cam = GradCAM(model, img_path, layer_idx=-2, model_type='Normal')
cam(heatmap_threshold=100)
cam_vit = GradCAM(model,img_path, layer_name='layer4.1.conv2', model_type='Normal')
cam_vit(heatmap_threshold=20)
cam.imposing_visualization()
```

### Vision Transformer (ViT)

> **Note** ViT ends by taking only the [cls] patch of the backbone (encoder) into encoder, may need modify for better visualization result
> **Just Note** ViT ends by taking only the [cls] patch of the backbone (encoder) into classfication header.
```python
from grad_cam_code.grad_cam import *
from timm.data.transforms_factory import create_transform
Expand All @@ -89,11 +99,34 @@ model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'

cam = GradCAM(model, img_path, layer_idx=-4, model_type='ViT', transform=transform)
cam(heatmap_threshold=5)
cam_vit = GradCAM(model,img_path, layer_name='blocks.10.drop_path2', model_type='ViT', transform = transform, verbose=True)
cam_vit(heatmap_threshold=5)
cam.imposing_visualization()
```

### DeiT

```python
from grad_cam_code.grad_cam import *
from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config

model = create_model('timm/deit_small_distilled_patch16_224.fb_in1k', pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'

cam = GradCAM(model,img_path,layer_name='blocks.10.drop_path2', model_type='DeiT', transform = transform,verbose=True)
cam(heatmap_threshold=8)
# Specify denormalize to undo ImageNet normalization when saving
cam.imposing_visualization(
save_path="img/swt_test",
denormalize=([0.4850, 0.4560, 0.4060], [0.2290, 0.2240, 0.2250])
)
```

### Swin Transformer

```python
Expand All @@ -108,9 +141,8 @@ model.eval()

img_path = 'graphs/test_images/test2-pug-dog.png'

cam = GradCAM(model, img_path, layer_idx=-1, model_type='SwinT',
auto_find_classfier=True, transform=transform)
cam(heatmap_threshold=20)
cam = GradCAM(model,img_path,layer_name='layers.3.blocks.1.drop_path2', model_type='SwinT', transform = transform,verbose=True)
cam(heatmap_threshold=40)
# Specify denormalize to undo ImageNet normalization when saving
cam.imposing_visualization(
save_path="img/swt_test",
Expand All @@ -131,6 +163,39 @@ Grad-CAM computes a class-discriminative localization map by:
5. Upsampling the resulting heatmap back to the original image size via bilinear interpolation and overlaying it using a jet colormap.

The `model_type` parameter controls how activations are reshaped:
- `Normal`: expects `(B, C, H, W)` feature maps (standard CNN conv outputs).
- `ViT`: reshapes the sequence dimension `(B, HW, C)` into a spatial grid `(B, H, W, C)`.
- `SwinT`: uses the output as-is (already in spatial format).
- `Normal` (CNN-based): expects `(B, C, H, W)` feature maps (standard CNN conv outputs).
- `ViT/SwinT/DeiT` (transformer-based): reshapes the sequence dimension `(B, HW, C)` into a spatial grid `(B, H, W, C)`.

## Reference:
Theoretical support:
```
@article{Selvaraju_2019,
title={Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization},
volume={128},
ISSN={1573-1405},
url={http://dx.doi.org/10.1007/s11263-019-01228-7},
DOI={10.1007/s11263-019-01228-7},
number={2},
journal={International Journal of Computer Vision},
publisher={Springer Science and Business Media LLC},
author={Selvaraju, Ramprasaath R. and Cogswell, Michael and Das, Abhishek and Vedantam, Ramakrishna and Parikh, Devi and Batra, Dhruv},
year={2019},
month=oct, pages={336–359} }
```

Part of the code implementation happened during my research on this paper:
```
@article{CHEN2024100332,
title = {A vision transformer machine learning model for COVID-19 diagnosis using chest X-ray images},
journal = {Healthcare Analytics},
volume = {5},
pages = {100332},
year = {2024},
issn = {2772-4425},
doi = {https://doi.org/10.1016/j.health.2024.100332},
url = {https://www.sciencedirect.com/science/article/pii/S2772442524000340},
author = {Tianyi Chen and Ian Philippi and Quoc Bao Phan and Linh Nguyen and Ngoc Thang Bui and Carlo daCunha and Tuy Tan Nguyen},
keywords = {Computer-aided diagnosis, Machine learning, Vision transformer, Efficient neural networks, COVID-19, Chest X-ray},
}
```

191 changes: 0 additions & 191 deletions README.old.md

This file was deleted.

Binary file modified grad_cam_code/__pycache__/grad_cam.cpython-312.pyc
Binary file not shown.
Binary file not shown.
Loading