Diagnostic autopsy for PyTorch training runs.
torchmortem hooks into your PyTorch training
loop and produces a "postmortem" diagnostic report telling you why your training
might be broken and how to fix it.
If you need full-fledged experiment tracking, hyperparameter sweeps, or collaborative dashboards, this is probably not the right tool; look instead at platforms like Weights & Biases or TensorBoard.
Install into a virtual environment with uv:
```shell
uv venv .venv
source .venv/bin/activate
uv pip install -e ".[dev]"
```

Or install from PyPI:

```shell
pip install torchmortem
```

Requires Python >= 3.10 and PyTorch >= 2.0.
```python
from torchmortem import Autopsy

with Autopsy(model, optimizer=optimizer) as autopsy:
    for epoch in range(num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()
            autopsy.step(loss=loss.item())

autopsy.report("autopsy_report.html")
```

torchmortem is built on a plugin architecture for extensibility:
- Collectors (implementing `Collector` in `collectors/base.py`) attach PyTorch hooks to a model and record raw signals during training.
- Detectors (implementing `Detector` in `detectors/base.py`) analyze the collected signals and return findings.
- The interpreter (`DefaultInterpreter` in `interpreters/default.py`, overridable via the protocol in `interpreters/base.py`) synthesizes the findings from all detectors by applying the rules defined in `interpreters/rules`.
- The interpreter produces human-readable reports using renderers (defined in `renderers/`).
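As a rough sketch of how the plugin seams fit together, a detector only needs a name, the collectors it depends on, and an `analyze` method. The protocol below is illustrative, written from the shape of the custom-detector example later in this README, not torchmortem's published API:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Detector(Protocol):
    """Minimal shape a detector plugin is expected to have (illustrative)."""

    name: str
    required_collectors: list[str]

    def analyze(self, collector_states: dict[str, Any], metadata: Any) -> list[Any]:
        """Turn collected signals into a list of findings."""
        ...


class NoopDetector:
    """A do-nothing detector that satisfies the protocol."""

    name = "noop"
    required_collectors: list[str] = []

    def analyze(self, collector_states, metadata):
        return []
```

Structural typing means third-party detectors never have to subclass anything from torchmortem; they only have to match the shape.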
Individual detectors:
- Vanishing / exploding gradients -- inter-layer gradient ratio analysis
- Dead units -- persistently inactive neurons (dead ReLU problem)
- Activation saturation -- sigmoid/tanh layers stuck in flat regions
- Unhealthy update ratios -- ||update||/||weight|| deviating from ~1e-3
- Loss dynamics -- catapult phase, edge-of-stability, plateaus, divergence
- Rank collapse -- representation dimensionality shrinking over training
- Weight norm pathologies -- explosion, stagnation, inter-layer imbalance
- Gradient noise -- SNR and batch size efficiency
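To make one of these concrete, the update-ratio heuristic fits in a few lines of numpy. The function names and the 10x tolerance band below are illustrative choices, not torchmortem's internals:

```python
import numpy as np


def update_ratio(update: np.ndarray, weight: np.ndarray) -> float:
    """L2 norm of the parameter update divided by the L2 norm of the weight."""
    return float(np.linalg.norm(update) / np.linalg.norm(weight))


def classify_ratio(ratio: float, target: float = 1e-3, tolerance: float = 10.0) -> str:
    """Flag ratios more than `tolerance`x away from the ~1e-3 rule of thumb."""
    if ratio > target * tolerance:
        return "too_large"   # updates dwarf the weights -> unstable training
    if ratio < target / tolerance:
        return "too_small"   # updates barely move the weights -> slow learning
    return "healthy"
```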
Cross-signal insights (correlation rules):
- Gradient starvation -- vanishing gradients + dead units
- Instability feedback loop -- exploding gradients + weight explosion
- Representation bottleneck -- rank collapse + loss stagnation
- Curvature traps -- edge-of-stability + plateau
Report features:
- Executive summary -- 3-5 sentence assessment with the top recommendation
- Per-layer health scores -- 0-1 score for each layer, visualized as a heatmap
- Interactive charts -- loss curve, gradient norms, weight norms, update ratios, dead unit fractions, effective rank
- Cross-signal insights -- root-cause explanations synthesized from multiple detectors
- Findings -- each with severity, explanation, affected layers, remediation, and references
- JSON output -- for CI pipelines and programmatic analysis
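The JSON output lends itself to gating CI on training health. A hedged sketch of a gate script; the report schema shown here (a top-level `findings` list with `severity` and `title` fields) is an assumption for illustration, so check it against an actual report before relying on it:

```python
import json
import sys


def blocking_findings(report_json: str,
                      blocking: tuple[str, ...] = ("ERROR", "CRITICAL")) -> list[str]:
    """Return titles of findings whose severity should fail the build."""
    report = json.loads(report_json)
    return [f["title"] for f in report.get("findings", [])
            if f.get("severity") in blocking]


if __name__ == "__main__":
    # Usage: python ci_gate.py autopsy_report.json
    with open(sys.argv[1]) as fh:
        hits = blocking_findings(fh.read())
    if hits:
        print("Blocking findings:", *hits, sep="\n  - ")
        sys.exit(1)
```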
Control the overhead/detail tradeoff:
```python
# Presets
Autopsy(model, sampling="thorough")  # max detail
Autopsy(model, sampling="balanced")  # default
Autopsy(model, sampling="fast")      # minimal overhead

# Granular control
from torchmortem import SamplingConfig

Autopsy(model, sampling=SamplingConfig(
    default_interval=1,
    expensive_interval=50,
    overrides={"curvature": 20},
))
```

See the examples/ directory:
- `basic_mlp.py` -- Deep MLP with sigmoid activations (vanishing gradients, dead units)
- `healthy_resnet.py` -- Well-configured residual network
- `transformer_debug.py` -- Transformer with high LR and no clipping
- `cnn_overfit.py` -- Small CNN that overfits on a toy image dataset
- `lstm_vanishing.py` -- Vanilla LSTM with extreme sequence length
Contributions are welcome! The plugin architecture aims to make it relatively easy for contributors to add new features. Please refer to CONTRIBUTING.md for guidelines.
Here is a complete example of a custom detector that flags any layer whose gradient norm exceeds a configurable threshold:
```python
import numpy as np

from torchmortem.registry import register_detector
from torchmortem.types import CollectorState, Finding, RunMetadata, Severity


@register_detector
class LargeGradientDetector:
    """Flags layers where the gradient norm exceeds a fixed threshold."""

    name: str = "large_gradient"
    required_collectors: list[str] = ["gradient"]

    def __init__(self, threshold: float = 100.0) -> None:
        self._threshold = threshold

    def analyze(
        self,
        collector_states: dict[str, CollectorState],
        metadata: RunMetadata,
    ) -> list[Finding]:
        grad_state = collector_states["gradient"]
        norms = grad_state.series.get("grad_norm")
        if norms is None or len(grad_state.steps) == 0:
            return []

        findings: list[Finding] = []
        for idx, layer in enumerate(grad_state.layers):
            layer_norms = norms[:, idx]
            max_norm = float(np.max(layer_norms))
            if max_norm > self._threshold:
                findings.append(Finding(
                    detector=self.name,
                    severity=Severity.WARNING,
                    category="gradient_flow",
                    title=f"Large gradient in {layer}",
                    summary=(
                        f"Gradient norm in {layer} reached {max_norm:.1f}, "
                        f"exceeding the {self._threshold:.1f} threshold."
                    ),
                    detail=(
                        f"The maximum gradient L2 norm observed in {layer} was "
                        f"{max_norm:.1f}. Large gradients can destabilize "
                        "training and cause weight explosion."
                    ),
                    affected_layers=[layer],
                    step_range=(int(grad_state.steps[0]), int(grad_state.steps[-1])),
                    remediation=[
                        "Add gradient clipping (torch.nn.utils.clip_grad_norm_).",
                        "Reduce the learning rate.",
                    ],
                ))
        return findings
```

The detector is picked up automatically as long as the module containing it is imported before the Autopsy context manager is entered.
torchmortem is provided under the MIT License.