An implementation of explainable AI techniques for image classification. CIAO identifies influential image regions by systematically segmenting images, obfuscating segments, and using search algorithms to find important regions (hyperpixels).
CIAO explains which regions of an image contribute to a neural network's classification decisions. The method:
- Segments the image into small regions
- Obfuscates each segment and measures impact on model predictions
- Uses search algorithms to group adjacent important segments into hyperpixels
- Generates explanations showing which regions influenced the prediction
# Clone the repository
git clone https://github.com/RationAI/ciao.git
cd ciao
# Install dependencies using uv
uv sync
Explain a single image with default settings:
uv run ciao
Customize the explanation using Hydra configuration overrides:
uv run ciao data.image_path=./my_image.jpg explanation.method=mcts explanation.segment_size=8
Alternatively, run as a module:
uv run python -m ciao
Common commands:
- `uv sync` - Install all dependencies
- `uv add <package>` - Add a new dependency
- `uv run ruff check` - Run linting
- `uv run ruff format` - Format code
- `uv run mypy .` - Run type checking
- `uv run ciao` - Run CIAO with default configuration
- `uv run pytest tests` - Execute tests
How it works:
- Segmentation: The input image is divided into small regions (segments) using hexagonal or square grids
- Score Calculation: Each segment is obfuscated (replaced) and the model is queried to measure how much that segment affects the prediction. This gives an importance score to each segment
- Hyperpixel Search: A search algorithm finds groups of adjacent segments with high importance scores, creating "hyperpixels" that represent influential image regions
- Explanation: The top hyperpixels are visualized to show which regions most influenced the model's prediction
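
The score calculation step can be sketched in a few lines. This is a minimal illustration rather than the code in this repository: `score_segments`, `toy_predict`, and the constant `fill_value` are hypothetical names chosen for the example, and the "model" is a stand-in function that returns a single class score.

```python
import numpy as np

def score_segments(image, labels, predict, fill_value=0.0):
    """Return {segment_id: drop in the model score when that segment is obfuscated}."""
    baseline = predict(image)
    scores = {}
    for seg_id in np.unique(labels):
        masked = image.copy()
        masked[labels == seg_id] = fill_value  # obfuscate exactly one segment
        scores[int(seg_id)] = baseline - predict(masked)
    return scores

# Toy setup: a 4x4 grayscale image split into four 2x2 segments (ids 0..3).
labels = np.kron(np.arange(4).reshape(2, 2), np.ones((2, 2), dtype=int))
image = np.ones((4, 4))

def toy_predict(img):
    # Hypothetical "model": the class score is the mean brightness of the top-left block.
    return float(img[:2, :2].mean())

scores = score_segments(image, labels, toy_predict)
print(scores)
```

Only segment 0 overlaps the region the toy model looks at, so it is the only segment whose removal lowers the score. A real run would replace `toy_predict` with the classifier's probability for the target class.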
Search algorithms:
- MCTS (Monte Carlo Tree Search): Tree-based search with UCB exploration
- MC-RAVE: MCTS with Rapid Action Value Estimation
- MCGS (Monte Carlo Graph Search): Graph-based variant allowing revisiting of states
- MCGS-RAVE: MCGS with RAVE enhancements
- Lookahead: Greedy search with lookahead using efficient bitset operations
- Potential: Potential field-guided sequential search
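
Two ideas from the list above can be illustrated on their own: the UCB rule that MCTS-style methods use to pick which child to explore, and the use of integer bitsets to represent a set of segments. The sketch below is an assumption-laden simplification, not the repository's search code: `ucb1` is the textbook UCB1 formula, and `grow_hyperpixel` is a plain one-step greedy growth (no lookahead, no tree) with hypothetical names and a hand-written adjacency table.

```python
import math

def ucb1(total_reward, visits, parent_visits, c=1.4):
    """Textbook UCB1 value of a child node; unvisited children are tried first."""
    if visits == 0:
        return math.inf
    return total_reward / visits + c * math.sqrt(math.log(parent_visits) / visits)

def grow_hyperpixel(scores, neighbors, seed, size):
    """Greedily grow a connected set of segments, stored as an int bitmask."""
    mask = 1 << seed
    frontier = neighbors[seed]  # segments adjacent to the current set
    for _ in range(size - 1):
        candidates = [s for s in range(len(scores)) if (frontier >> s) & 1]
        if not candidates:
            break
        best = max(candidates, key=lambda s: scores[s])
        mask |= 1 << best
        # New frontier: all neighbours of the set that are not already in it.
        frontier = (frontier | neighbors[best]) & ~mask
    return mask

# 2x2 grid of segments laid out as (0 1 / 2 3) with 4-connectivity.
# neighbors[i] is a bitmask of the segments adjacent to segment i.
neighbors = [0b0110, 0b1001, 0b1001, 0b0110]
scores = [0.9, 0.1, 0.7, 0.2]

mask = grow_hyperpixel(scores, neighbors, seed=0, size=3)
members = [s for s in range(4) if (mask >> s) & 1]
print(members)
```

Storing the set as a single integer makes membership tests, unions, and frontier updates single bitwise operations, which is the reason bitsets are attractive for this kind of search.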
Segmentation methods:
- Hexagonal Grid: Divides image into hexagonal cells for better spatial coverage
- Square Grid: Simple square grid segmentation
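
Both grids can be expressed as a label map that assigns a segment id to every pixel. The following is a minimal NumPy sketch under stated assumptions: `square_grid` and `hex_grid` are hypothetical helpers, not functions from `segmentation.py`, and the hexagonal variant uses the fact that nearest-centre cells of a staggered (triangular) lattice are hexagons. The brute-force distance computation is only suitable for small images.

```python
import numpy as np

def square_grid(height, width, size):
    """Label map where each size x size block of pixels shares one segment id."""
    rows = np.arange(height) // size
    cols = np.arange(width) // size
    n_cols = -(-width // size)  # ceiling division
    return rows[:, None] * n_cols + cols[None, :]

def hex_grid(height, width, size):
    """Label map from the nearest centre on a staggered lattice (hexagonal cells)."""
    dy = size * np.sqrt(3) / 2  # row spacing that makes the cells regular hexagons
    centers = []
    for r in range(int(height / dy) + 2):
        offset = size / 2 if r % 2 else 0.0  # shift every other row by half a cell
        for c in range(int(width / size) + 2):
            centers.append((r * dy, c * size + offset))
    centers = np.array(centers)
    ys, xs = np.mgrid[0:height, 0:width]
    pixels = np.stack([ys.ravel(), xs.ravel()], axis=1)
    d2 = ((pixels[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    nearest = d2.argmin(axis=1)
    # Renumber so ids are contiguous (some lattice centres own no pixel).
    _, labels = np.unique(nearest, return_inverse=True)
    return labels.reshape(height, width)

sq = square_grid(4, 4, 2)
hx = hex_grid(8, 8, 4)
print(sq)
```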
Obfuscation strategies:
- Mean Color: Replace masked regions with the image's mean color (normalized)
- Blur: Gaussian blur applied to masked regions
- Interlacing: Interlaced pattern replacement
- Solid Color: Replace with a specified solid color (RGB)
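
Three of these strategies are easy to sketch with NumPy alone. The helpers below are hypothetical illustrations, not the repository's implementation: they assume a float image of shape `(H, W, 3)` and a boolean mask of shape `(H, W)`, and the blur uses a hand-rolled separable Gaussian kernel. Interlacing is left out because its exact pattern is not described here.

```python
import numpy as np

def obfuscate_mean(image, mask):
    """Replace masked pixels with the image's per-channel mean colour."""
    out = image.copy()
    out[mask] = image.mean(axis=(0, 1))
    return out

def obfuscate_solid(image, mask, color):
    """Replace masked pixels with a fixed RGB colour."""
    out = image.copy()
    out[mask] = np.asarray(color, dtype=image.dtype)
    return out

def obfuscate_blur(image, mask, sigma=2.0):
    """Replace masked pixels with a Gaussian-blurred copy of the image."""
    radius = int(3 * sigma)
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()
    padded = np.pad(image, ((radius, radius), (radius, radius), (0, 0)), mode="edge")

    def smooth(v):
        return np.convolve(v, kernel, mode="valid")

    blurred = np.apply_along_axis(smooth, 0, padded)   # blur along rows
    blurred = np.apply_along_axis(smooth, 1, blurred)  # then along columns
    out = image.copy()
    out[mask] = blurred[mask]
    return out

image = np.zeros((4, 4, 3))
image[0, 0] = [16.0, 0.0, 0.0]       # one bright red pixel
mask = np.zeros((4, 4), dtype=bool)
mask[1, 1] = True                    # obfuscate a single pixel

mean_out = obfuscate_mean(image, mask)
solid_out = obfuscate_solid(image, mask, (0.5, 0.5, 0.5))
blur_out = obfuscate_blur(np.ones((8, 8, 3)), np.ones((8, 8), dtype=bool))
```

Each helper returns a copy, so the original image can be reused for scoring the next segment.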
ciao/
├── ciao/ # Main package
│ ├── algorithm/ # Search algorithms and data structures
│ │ ├── mcts.py # Monte Carlo Tree Search
│ │ ├── mcgs.py # Monte Carlo Graph Search
│ │ ├── lookahead_bitset.py # Greedy lookahead with bitsets
│ │ ├── potential.py # Potential-based search
│ │ ├── bitmask_graph.py # Bitset operations for hyperpixels
│ │ ├── nodes.py # Node classes for tree/graph search
│ │ └── search_helpers.py # Shared MCTS/MCGS helper functions
│ ├── data/ # Data loading and preprocessing
│ │ ├── loader.py # Image loaders
│ │ ├── preprocessing.py # Image preprocessing utilities
│ │ └── segmentation.py # Segmentation utilities (hex/square grids)
│ ├── evaluation/ # Scoring and evaluation
│ │ ├── surrogate.py # Surrogate dataset creation and segment scoring
│ │ └── hyperpixel.py # Hyperpixel evaluation and selection
│ ├── explainer/ # Core explainer implementation
│ │ └── ciao_explainer.py # Main CIAO explainer class
│ ├── model/ # Model inference and predictions
│ │ └── predictor.py # ModelPredictor class for inference
│ ├── visualization/ # Visualization tools
│ │ ├── visualization.py # Interactive visualizations
│ │ └── visualize_tree.py # Tree/graph visualization utilities
│ └── __main__.py # CLI entry point
├── configs/ # Hydra configuration files
│ ├── ciao.yaml # Main entry point
│ ├── base.yaml # Base configuration
│ ├── data/ # Data configurations
│ │ └── default.yaml
│ ├── explanation/ # Explanation method configs
│ │ └── ciao_default.yaml # Default CIAO parameters
│ ├── hydra/ # Hydra settings
│ └── logger/ # Logger configurations
└── pyproject.toml # Project metadata and dependencies