10 changes: 10 additions & 0 deletions docs/api/interpret.rst
@@ -67,6 +67,15 @@ New to interpretability in PyHealth? Check out these complete examples:
- Test various distance kernels (cosine vs euclidean) and sample sizes
- Decode attributions to human-readable medical codes and lab measurements

**Grad-CAM Example:**

- ``examples/cxr/gradcam_cxr_tutorial.py`` - Demonstrates Grad-CAM for CNN-based medical image classification. Shows how to:

- Choose a target convolutional layer from a PyHealth image model
- Generate class-conditional heatmaps for chest X-ray images
- Overlay the Grad-CAM heatmap on the original image for interpretation
- Run the example from a dataset path without editing the source file

These examples provide end-to-end workflows from loading data to interpreting and evaluating attributions.

Attribution Methods
@@ -82,6 +91,7 @@ Attribution Methods
interpret/pyhealth.interpret.methods.integrated_gradients
interpret/pyhealth.interpret.methods.shap
interpret/pyhealth.interpret.methods.lime
interpret/pyhealth.interpret.methods.gradcam

Visualization Utilities
-----------------------
62 changes: 62 additions & 0 deletions docs/api/interpret/pyhealth.interpret.methods.gradcam.rst
@@ -0,0 +1,62 @@
pyhealth.interpret.methods.gradcam
==================================

Overview
--------

Grad-CAM provides class-conditional heatmaps for CNN-based image
classification models in PyHealth. It uses gradients from a target
convolutional layer to highlight which image regions contributed most to the
selected prediction.
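The computation this describes can be sketched in plain PyTorch. This is a minimal illustration of the Grad-CAM recipe (gradient-weighted activation maps followed by a ReLU), not the PyHealth implementation; the toy CNN below is a hypothetical stand-in for a real backbone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy CNN standing in for a real backbone (hypothetical, for illustration).
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 3),
)
target_layer = model[0]

activations, gradients = {}, {}
target_layer.register_forward_hook(lambda m, i, o: activations.update(a=o))
target_layer.register_full_backward_hook(lambda m, gi, go: gradients.update(g=go[0]))

x = torch.randn(2, 1, 16, 16)
logits = model(x)
class_idx = logits.argmax(dim=1)
# Backpropagate only the selected class score for each sample.
score = logits.gather(1, class_idx.unsqueeze(1)).sum()
score.backward()

# Grad-CAM: channel weights are gradients averaged over space; the CAM is
# the ReLU of the weighted activation sum, with shape [B, H, W].
weights = gradients["g"].mean(dim=(2, 3), keepdim=True)
cam = F.relu((weights * activations["a"]).sum(dim=1))
print(cam.shape)  # torch.Size([2, 16, 16])
```

The ReLU keeps only regions that push the score up, which is why Grad-CAM heatmaps are non-negative.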

This method is intended for:

- CNN image classification models
- chest X-ray and other medical imaging workflows built on PyHealth image tasks
- models that return either ``logit`` or ``y_prob``

Usage Notes
-----------

1. **CNN model**: Grad-CAM requires a 4D convolutional activation map from
the target layer.
2. **Target layer**: You can pass either an ``nn.Module`` directly or a dotted
string path such as ``"model.layer4.1.conv2"``.
3. **Class selection**: If ``class_index`` is omitted, Grad-CAM uses the
predicted class. For single-output binary models, it attributes to that
scalar output.
4. **Gradients required**: Do not call ``attribute()`` inside
``torch.no_grad()``.
5. **Return shape**: ``attribute()`` returns ``{input_key: cam}`` where the CAM
tensor has shape ``[B, H, W]``.
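Note 2's dotted-string form can be resolved with PyTorch's built-in ``nn.Module.get_submodule``, which walks attribute names and numeric container indices alike. A minimal sketch (the container below is a toy stand-in; whether PyHealth uses this exact mechanism internally is an assumption):

```python
import torch.nn as nn

# Toy container standing in for a real backbone (hypothetical structure).
net = nn.Module()
net.layer4 = nn.Sequential(nn.Conv2d(8, 8, 3), nn.Conv2d(8, 8, 3))

# get_submodule resolves dotted paths, including numeric indices into
# sequential containers such as the "1" here.
target = net.get_submodule("layer4.1")
print(type(target).__name__)  # Conv2d
```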

Quick Start
-----------

.. code-block:: python

from pyhealth.interpret.methods import GradCAM
from pyhealth.interpret.utils import visualize_image_attr

gradcam = GradCAM(
model,
target_layer=model.model.layer4[-1].conv2,
input_key="image",
)
cams = gradcam.attribute(**batch)
image, heatmap, overlay = visualize_image_attr(
image=batch["image"][0],
attribution=cams["image"][0],
)

For a complete script example, see:
``examples/cxr/gradcam_cxr_tutorial.py``
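Because the returned CAM has the spatial resolution of the target layer's activation map (small relative to the input for deep layers), overlaying it on the original image generally requires upsampling and normalization. A minimal sketch; the 7x7 map and 224x224 image size are illustrative assumptions, not fixed PyHealth behavior:

```python
import torch
import torch.nn.functional as F

cam = torch.rand(1, 7, 7)   # [B, H, W], as returned by attribute()
image_size = (224, 224)

# interpolate expects [B, C, H, W]: add a channel dim, resize, squeeze back.
cam_up = F.interpolate(
    cam.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False
).squeeze(1)

# Min-max normalize per sample so the heatmap maps cleanly onto [0, 1].
flat = cam_up.view(cam_up.size(0), -1)
mn = flat.min(dim=1).values.view(-1, 1, 1)
mx = flat.max(dim=1).values.view(-1, 1, 1)
cam_norm = (cam_up - mn) / (mx - mn + 1e-8)
print(cam_norm.shape)  # torch.Size([1, 224, 224])
```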

API Reference
-------------

.. autoclass:: pyhealth.interpret.methods.gradcam.GradCAM
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
2 changes: 2 additions & 0 deletions docs/tutorials.rst
@@ -208,6 +208,8 @@ These examples are located in ``examples/cxr/``.
- Conformal prediction for COVID-19 CXR classification
* - ``cxr/cnn_cxr.ipynb``
- CNN for chest X-ray classification (notebook)
* - ``cxr/gradcam_cxr_tutorial.py``
- Grad-CAM for CNN-based chest X-ray classification
* - ``cxr/chestxray14_binary_classification.ipynb``
- Binary classification on ChestX-ray14 dataset (notebook)
* - ``cxr/chestxray14_multilabel_classification.ipynb``
181 changes: 181 additions & 0 deletions examples/cxr/gradcam_cxr_tutorial.py
@@ -0,0 +1,181 @@
"""Grad-CAM tutorial for CNN-based chest X-ray classification in PyHealth.

Prerequisites:
- A local COVID-19 Radiography Database root passed with ``--root``

Notes:
- For meaningful class-specific visualizations, pass ``--checkpoint`` with a
trained PyHealth checkpoint. Without a checkpoint, the script still runs as a
pipeline example, but the classification head is randomly initialized.
- ``--weights DEFAULT`` may trigger a first-run torchvision download. Use
``--weights none`` for an offline run.

Usage:
    python gradcam_cxr_tutorial.py --root /path/to/dataset --weights none
"""

from __future__ import annotations

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import torch

from pyhealth.datasets import COVID19CXRDataset, SampleDataset, get_dataloader
from pyhealth.interpret.methods import GradCAM
from pyhealth.interpret.utils import visualize_image_attr
from pyhealth.models import TorchvisionModel


def parse_args() -> argparse.Namespace:
"""Parse command-line arguments for the Grad-CAM tutorial.

Returns:
argparse.Namespace: Parsed CLI arguments controlling dataset location,
model initialization, runtime device, and output path.
"""
parser = argparse.ArgumentParser(
description="Run Grad-CAM on one chest X-ray sample.",
)
parser.add_argument(
"--root",
required=True,
help="Path to the COVID-19 Radiography Database root directory.",
)
parser.add_argument(
"--checkpoint",
default=None,
help="Optional checkpoint to load before inference.",
)
parser.add_argument(
"--output",
default="gradcam_cxr_overlay.png",
help="Where to save the Grad-CAM figure.",
)
parser.add_argument(
"--device",
default=None,
help="Optional device override such as 'cpu' or 'cuda:0'.",
)
parser.add_argument(
"--weights",
choices=["DEFAULT", "none"],
default="DEFAULT",
help="Torchvision backbone weights to use when initializing resnet18.",
)
return parser.parse_args()


def resolve_device(device_arg: str | None) -> str:
"""Resolve the device string used for inference.

Args:
device_arg: Optional CLI override such as ``"cpu"`` or ``"cuda:0"``.

Returns:
str: The resolved device string.
"""
if device_arg is not None:
return device_arg
return "cuda:0" if torch.cuda.is_available() else "cpu"


def load_dataset(root: str) -> SampleDataset:
"""Load the COVID-19 CXR sample dataset for the tutorial.

Args:
root: Root directory containing the COVID-19 Radiography Database.

Returns:
SampleDataset: Task-applied sample dataset ready for dataloader use.

Raises:
SystemExit: If ``openpyxl`` is required but unavailable.
"""
try:
dataset = COVID19CXRDataset(root, num_workers=1)
return dataset.set_task(num_workers=1)
except ImportError as exc:
if "openpyxl" in str(exc):
raise SystemExit(
"This example needs 'openpyxl' to read the raw metadata sheets. "
"Install it with: pip install openpyxl"
) from exc
raise


def main() -> None:
"""Run Grad-CAM on a single chest X-ray sample and save a figure."""
args = parse_args()
root = Path(args.root).expanduser()
if not root.exists():
raise SystemExit(f"Dataset root does not exist: {root}")

sample_dataset = load_dataset(str(root))
loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False)
batch = next(iter(loader))

weights = None if args.weights == "none" else "DEFAULT"
model = TorchvisionModel(
dataset=sample_dataset,
model_name="resnet18",
model_config={"weights": weights},
)
device = resolve_device(args.device)
model = model.to(device)
model.eval()

if args.checkpoint:
checkpoint_path = Path(args.checkpoint).expanduser()
if not checkpoint_path.exists():
raise SystemExit(f"Checkpoint does not exist: {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
print(f"Loaded checkpoint from {checkpoint_path}")
else:
print(
"Warning: no checkpoint provided. The classifier head is randomly "
"initialized, so this run is only a pipeline example."
)

with torch.no_grad():
y_prob = model(**batch)["y_prob"][0]

label_vocab = sample_dataset.output_processors["disease"].label_vocab
pred_class = int(torch.argmax(y_prob).item())
id2label = {value: key for key, value in label_vocab.items()}
pred_label = id2label[pred_class]

gradcam = GradCAM(
model,
target_layer=model.model.layer4[-1].conv2,
input_key="image",
)
cam = gradcam.attribute(class_index=pred_class, **batch)["image"]

image, heatmap, overlay = visualize_image_attr(
image=batch["image"][0],
attribution=cam[0],
)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(image, cmap="gray")
axes[0].set_title("Input")
axes[0].axis("off")

axes[1].imshow(heatmap, cmap="jet")
axes[1].set_title("Grad-CAM")
axes[1].axis("off")

axes[2].imshow(overlay)
axes[2].set_title(f"Overlay: {pred_label}")
axes[2].axis("off")

output_path = Path(args.output).expanduser()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.tight_layout()
plt.savefig(output_path, dpi=150)
print(f"Predicted class: {pred_label}")
print(f"Saved Grad-CAM visualization to {output_path.resolve()}")


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion pyhealth/interpret/methods/__init__.py
@@ -11,6 +11,7 @@
from pyhealth.interpret.methods.ensemble_crh import CrhEnsemble
from pyhealth.interpret.methods.ensemble_avg import AvgEnsemble
from pyhealth.interpret.methods.ensemble_var import VarEnsemble
from pyhealth.interpret.methods.gradcam import GradCAM

__all__ = [
"BaseInterpreter",
@@ -25,5 +26,6 @@
"LimeExplainer",
"CrhEnsemble",
"AvgEnsemble",
"VarEnsemble"
"VarEnsemble",
"GradCAM",
]