This project is a benchmarking framework for comparing class-aware and traditional pruning methods on convolutional neural networks (CNNs). I implemented it for my master's thesis on benchmarking class-aware pruning techniques. The pruning algorithms can be evaluated on different datasets (CIFAR10, ImageNet, GTSRB) and model architectures (VGG16, ResNet18) by measuring accuracy, inference time, and model size before and after pruning.
This is a schematic overview of the pruning pipeline. In `selection.py` we implement selector objects that apply different pruning algorithms for filter selection. The output is a dictionary mapping each convolutional layer to the indices of the filters to prune. Next comes the filter removal step, where the selected filters are removed from the model; this is implemented in the pruner classes in `pruner.py`. Finally, there is an optional retraining step using samples from the selected classes.
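The selector output format can be illustrated with a small sketch (the layer names, scores, and the `select_filters` helper below are hypothetical stand-ins; the real selectors live in `selection.py`):

```python
# Hypothetical sketch of the selector output: a dict mapping each
# convolutional layer to the filter indices chosen for removal.
# The scoring (lowest "importance" first) is a stand-in for the real
# algorithms implemented in selection.py.

def select_filters(layer_scores, pruning_ratio):
    """Return {layer_name: [indices of filters to prune]}."""
    to_prune = {}
    for layer, scores in layer_scores.items():
        n_prune = int(len(scores) * pruning_ratio)
        # rank filter indices by ascending importance, prune the weakest
        ranked = sorted(range(len(scores)), key=lambda i: scores[i])
        to_prune[layer] = sorted(ranked[:n_prune])
    return to_prune

scores = {
    "features.0": [0.9, 0.1, 0.5, 0.05],
    "features.3": [0.2, 0.8],
}
print(select_filters(scores, 0.5))
# {'features.0': [1, 3], 'features.3': [0]}
```

A dict of plain index lists keeps the selection step decoupled from the pruner, which only needs to know *which* filters to drop, not *why*.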
- Configuration Loading: the Hydra config system loads dataset, model, pruning, and training parameters from `config/`
- Model Loading: `models.get_model()` retrieves torchvision models and adjusts output layers
- Data Preparation: `DataLoaderFactory` subclasses (CIFAR10, ImageNet, GTSRB) create train/test/pruning subset loaders
- Selection Phase: the `selection.get_selector()` factory returns a pruning strategy instance that selects filters
- Pruning Phase: `DepGraphPruner` removes the selected filter indices via torch.fx symbolic tracing
- Evaluation: `metrics.measure_inference_time_and_accuracy()` measures per-class accuracy and inference time
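The steps above can be sketched as a simplified orchestration function. Only the `select` method name follows the project; the `prune` method and all parameter names here are placeholders, and training, retraining, and logging are omitted:

```python
# Rough sketch of the orchestration in main.py (simplified). Only
# selector.select() mirrors the project's interface; pruner.prune()
# and evaluate() are hypothetical placeholders.

def run_pipeline(cfg, model, selector, pruner, test_loader, evaluate):
    # Selection phase: strategy object picks filter indices per conv layer
    pruning_indices = selector.select(model)
    # Pruning phase: physically remove the selected filters
    pruned_model = pruner.prune(model, pruning_indices)
    # Evaluation: per-class accuracy and inference time
    return evaluate(pruned_model, test_loader)
```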
All pruning strategies inherit from the `PruningSelection` abstract base class and implement `select(model)` → dict of layer masks/indices:
- OCAP (`filter_selection/ocap.py`): computes activation statistics via forward hooks, applies per-layer pruning ratios (class-aware); based on https://github.com/mzd2222/OCAP
- LRP (`filter_selection/lrp.py`): Layer-wise Relevance Propagation; uses backpropagation to compute relevance scores (class-aware); based on https://github.com/seulkiyeom/LRP_Pruning_toy_example
- LnStructured: prunes by layer norm magnitude (not class-aware)
- TorchPruner: wraps the torch_pruning library with Taylor/APoZ attribution metrics (not class-aware); uses https://github.com/marcoancona/TorchPruner
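The shared interface can be sketched as follows. The `PruningSelection` base and `select()` come from the source; the L1-magnitude selector is a simplified stand-in (roughly LnStructured with n=1) that operates on plain lists instead of real weight tensors:

```python
from abc import ABC, abstractmethod

class PruningSelection(ABC):
    """Sketch of the abstract base all selectors inherit from."""

    @abstractmethod
    def select(self, model):
        """Return {layer_name: [filter indices to prune]}."""

class L1MagnitudeSelection(PruningSelection):
    """Toy magnitude-based selector (not class-aware); illustrative only."""

    def __init__(self, pruning_ratio):
        self.pruning_ratio = pruning_ratio

    def select(self, model):
        # 'model' here is assumed to be {layer_name: [[filter weights], ...]}
        indices = {}
        for layer, filters in model.items():
            norms = [sum(abs(w) for w in f) for f in filters]
            n_prune = int(len(filters) * self.pruning_ratio)
            ranked = sorted(range(len(norms)), key=norms.__getitem__)
            indices[layer] = sorted(ranked[:n_prune])
        return indices

toy_model = {"conv1": [[0.5, -0.5], [0.01, 0.02], [1.0, 1.0], [0.1, -0.1]]}
print(L1MagnitudeSelection(0.5).select(toy_model))
# {'conv1': [1, 3]}
```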
Example: pruning a VGG16 model trained on CIFAR10 with OCAP at an 85% pruning ratio for classes 0, 1, 2:

    python main.py model=vgg16 dataset=cifar10 training.train=false \
        model.pretrained_weights_path=<PATH_TO_PRETRAINED_WEIGHTS> \
        pruning=ocap \
        pruning.pruning_ratio="[0.85]" selected_classes=[0,1,2]

Further parameters can be adjusted in the config files or via CLI overrides.
All metrics are printed to the console. If `log_results=true`, results are also logged to Weights & Biases.
- Override pattern: CLI args override YAML (e.g., `training.retrain_after_pruning=true`)
- Config locations: `config/{pruning,model,dataset}/*.yaml` plus the `config/config.yaml` base
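A hypothetical base config illustrating this layout, following the standard Hydra defaults-list pattern (field and group names are assumptions modeled on the CLI overrides shown in this README; check `config/` for the actual schema):

```yaml
# config/config.yaml -- illustrative sketch only
defaults:
  - model: vgg16
  - dataset: cifar10
  - pruning: ocap

selected_classes: [0, 1, 2]
log_results: true
```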
- Skip early layers: `cfg.model.skip_first_layers` bypasses pruning the first N conv layers, as these are critical for feature extraction
- Dealing with skip connections in ResNet: `filter_pruning_indices_for_resnet()` in `helpers.py` is called for ResNets to ensure compatible pruning of skip connections (we don't prune the last conv layer in a block). If other architectures with skip connections are used, similar logic must be implemented
- Last Layer Replacement: when `replace_last_layer=true`, the linear output layer is replaced so its output dimension matches the number of `selected_classes`
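A minimal sketch of what such a replacement can look like, assuming a VGG-style head where the output layer sits at `model.classifier[-1]` (the project's actual logic lives in `models.py`):

```python
import torch.nn as nn

def replace_last_layer(model, selected_classes):
    """Swap the final nn.Linear so out_features == len(selected_classes).

    Sketch only: assumes the output layer is model.classifier[-1],
    as in torchvision's VGG16.
    """
    old = model.classifier[-1]
    model.classifier[-1] = nn.Linear(old.in_features, len(selected_classes))
    return model

# Toy stand-in for a VGG-style classifier head
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1000)
        )

net = replace_last_layer(TinyNet(), selected_classes=[0, 1, 2])
print(net.classifier[-1].out_features)  # 3
```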
- Data augmentation disabled by default: `use_data_augmentation=false` in the config for consistent pruning results
- Shuffle disabled in pruning loaders: `get_subset_dataloaders()` sets `shuffle=False`
- Seeding: `random.seed(42)` in `get_small_train_loader()`; use explicit seed control for full reproducibility
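The seeding note generalizes to a small helper. Only `random.seed(42)` in `get_small_train_loader()` is from the source; seeding numpy and torch as well is an added suggestion for full reproducibility:

```python
import random

def set_seed(seed=42):
    """Seed the RNGs relevant to this pipeline.

    random.seed() mirrors get_small_train_loader(); the numpy/torch
    seeding below is a suggested extension, applied only if the
    libraries are installed.
    """
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass

set_seed(42)
first = random.random()
set_seed(42)
assert random.random() == first  # same seed, same draw
```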
- Multiple ratios: `cfg.pruning.pruning_ratio` is a list; `main.py` loops over each ratio, creating a separate pruned model per ratio. This benchmarks multiple pruning levels in one run and avoids the small accuracy variations that can occur between separate runs.
- Ratio semantics: fraction of filters pruned per layer (0.85 = prune 85%, keep 15%)
- Example: `pruning_ratio: [0.00, 0.85, 0.88, 0.90, ...]` produces 11 pruned models
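The multi-ratio loop can be sketched as follows; `prune` and `evaluate` are placeholders for the real selection, pruning, and metrics steps:

```python
import copy

def benchmark_ratios(model, pruning_ratios, prune, evaluate):
    """Prune a fresh copy of the model at each ratio and collect metrics.

    Deep-copying keeps each run independent, mirroring how main.py
    creates a separate pruned model per ratio.
    """
    results = {}
    for ratio in pruning_ratios:
        pruned = prune(copy.deepcopy(model), ratio)
        results[ratio] = evaluate(pruned)
    return results

# Toy usage: "model" is a list of filters; pruning drops a fraction of them.
toy = list(range(10))
out = benchmark_ratios(
    toy,
    [0.0, 0.5],
    prune=lambda m, r: m[: int(len(m) * (1 - r))],
    evaluate=len,
)
print(out)  # {0.0: 10, 0.5: 5}
```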
- Flag: `cfg.inference_with_onnx=true` converts the model to ONNX and benchmarks via onnxruntime
- Purpose: measure real-world inference speed on CPU/GPU platforms
- Caveat: not all custom layers are supported; falls back to PyTorch if conversion fails
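The caveat above suggests a guarded pattern. A sketch (assuming torch and onnxruntime are installed for the ONNX path; any export or runtime failure falls back to timing the original model directly):

```python
import time

def benchmark_inference(model, example_input, runs=50):
    """Time average inference, preferring ONNX Runtime with a fallback.

    Sketch of the pattern behind cfg.inference_with_onnx: if export or
    onnxruntime is unavailable or fails, time the original model instead.
    """
    try:
        import torch, onnxruntime  # may raise ImportError
        torch.onnx.export(model, example_input, "model.onnx")
        session = onnxruntime.InferenceSession("model.onnx")
        name = session.get_inputs()[0].name
        run = lambda: session.run(None, {name: example_input.numpy()})
    except Exception:
        run = lambda: model(example_input)  # PyTorch fallback

    start = time.perf_counter()
    for _ in range(runs):
        run()
    return (time.perf_counter() - start) / runs
```

With a plain callable as `model`, the export step fails and the fallback path is exercised, which is the behavior the caveat describes.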
| File | Purpose |
|---|---|
| `main.py` | Orchestrates full pipeline: training → selection → pruning → retraining → evaluation |
| `selection.py` | Abstract base `PruningSelection` and all filter selection algorithms |
| `pruner.py` | `DepGraphPruner` applies selected indices via torch.fx symbolic trace; `StructurePruner` implements similar logic |
| `models.py` | Model factory; handles torchvision load, last-layer replacement |
| `data_loader.py` | `DataLoaderFactory` subclasses for CIFAR10/ImageNet/GTSRB |
| `metrics.py` | Accuracy, inference time, parameter ratio, FLOP counting |
| `config/` | Hydra YAML configs (base, pruning strategies, models, datasets) |
