A comprehensive research project investigating how deep learning models behave under distribution shift, with a focus on attribution-method robustness and model interpretability across different data distributions.
This project addresses a critical challenge in AI safety: What happens when deep learning models encounter data that differs from their training distribution? Through extensive experiments with Vision Transformers (ViT) and ResNet architectures, we demonstrate that:
- Models fail catastrophically on out-of-distribution (OOD) data while maintaining high confidence
- Attribution methods become unreliable precisely when interpretability is most needed
- Calibration breaks down, creating dangerous "confident but wrong" scenarios
- Architecture matters for both performance and interpretability under distribution shift
- Multi-dataset evaluation (CIFAR-10 → CIFAR-100, SVHN)
- Performance, calibration, and attribution drift metrics
- Semantic coherence analysis across distribution shifts
- Saliency Maps, Grad-CAM, Integrated Gradients, Attention Rollout
- Architecture-specific compatibility analysis
- Correlation and similarity metrics for attribution drift detection
- 71.73% accuracy drop on OOD data while confidence remains at 61.9%
- 84.7% attribution dissimilarity between ID and OOD explanations
- 16.9× calibration degradation (ECE: 0.035 → 0.598)
| Model | ID Accuracy | OOD Accuracy | Attribution IoU | Calibration ECE |
|---|---|---|---|---|
| ViT | 72.74% | 1.01% | 0.153 | 0.598 |
| ResNet | 77.02% | 0.90% | 0.123 | 0.662 |
Comprehensive performance analysis showing catastrophic failure on OOD data
Attribution method reliability breakdown across distribution shifts
Visual comparison of attribution methods on ID vs OOD data
- Neither architecture is OOD-safe without additional safeguards
- ViT shows better attribution stability but lower in-distribution accuracy
- ResNet offers computational efficiency but limited attribution consistency
- Attribution drift serves as OOD indicator (IoU < 0.2 signals distribution shift)
- Most Critical Finding: Attribution drift analysis (Phase 3) reveals whether explanations can be trusted when the data distribution changes, which is crucial for real-world deployment safety
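The IoU-based drift signal is simple to compute. Below is a minimal sketch of the idea, assuming attribution maps are available as tensors; the helper name, the top-10% binarization, and the placeholder maps are illustrative choices, not the project's exact implementation:

```python
import torch

def attribution_iou(attr_a: torch.Tensor, attr_b: torch.Tensor, top_frac: float = 0.1) -> float:
    """IoU of the top-`top_frac` most important pixels of two attribution maps."""
    k = max(1, int(top_frac * attr_a.numel()))
    top_a = set(attr_a.abs().flatten().topk(k).indices.tolist())
    top_b = set(attr_b.abs().flatten().topk(k).indices.tolist())
    return len(top_a & top_b) / len(top_a | top_b)

# Compare the attribution of a clean input with that of its shifted
# counterpart; per our results, IoU < 0.2 signals distribution shift.
attr_clean = torch.randn(3, 32, 32)    # placeholder attribution map
attr_shifted = torch.randn(3, 32, 32)  # placeholder attribution map
if attribution_iou(attr_clean, attr_shifted) < 0.2:
    print("Attribution drift detected: possible distribution shift")
```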
Try our interactive Colab notebook to see the critical safety findings in action:
This notebook demonstrates the key findings using cached experimental results - no training required!
git clone https://github.com/lauragomez/interpret_shifts.git
cd interpret_shifts
pip install -r requirements.txt
# Test with ResNet (quick validation)
python experiments/ood_analysis/run_ood_experiments.py \
--model_path models/resnet_cifar10_best.pth \
--model_type resnet \
--quick_test \
--output_dir results/resnet_quick
# Test with ViT (quick validation)
python experiments/ood_analysis/run_ood_experiments.py \
--model_path models/vit-hf-scratch-small_cifar10_best.pth \
--model_type vit \
--quick_test \
--output_dir results/vit_quick
# Train ViT from scratch
python experiments/training/main.py --model vit-hf-scratch --epochs 100 --lr 3e-4 --batch_size 128
# Train ResNet
python experiments/training/main.py --model resnet --epochs 50 --lr 1e-3 --batch_size 64
# Full ResNet experiment on CIFAR-100
python experiments/ood_analysis/run_ood_experiments_cifar100.py \
--model_path models/resnet_cifar10_best.pth \
--model_type resnet \
--output_dir results/resnet_cifar100
# Full ViT experiment on CIFAR-100
python experiments/ood_analysis/run_ood_experiments_cifar100.py \
--model_path models/vit-hf-scratch-small_cifar10_best.pth \
--model_type vit \
--output_dir results/vit_cifar100
# Analyze ResNet on SVHN
python experiments/ood_analysis/run_ood_experiments.py \
--model_path models/resnet_cifar10_best.pth \
--model_type resnet \
--ood_dataset svhn \
--output_dir results/resnet_svhn
# Compare multiple experiment results
python experiments/ood_analysis/compare_ood_results.py --results_dir results/
# Create comprehensive analysis dashboard
python experiments/visualization/visualize_ood_results.py --results_dir results/
# Visualize CIFAR-100 specific results
python experiments/visualization/visualize_cifar100_results.py --results_dir results/
Phase 1: Attribution Sanity Checks (~1-2 minutes)
- Tests if attribution methods are working correctly
- Uses simple linear models for validation
- Validates Saliency, Grad-CAM, Integrated Gradients, Attention Rollout
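As an illustration of what such a sanity check looks like (not the repo's exact test): for a linear model y = Wx + b, the input gradient for class c is exactly row c of W, so a correct saliency implementation must reproduce the weights:

```python
import torch

# Sanity check on a simple linear model: the saliency (input gradient)
# for class `target` must equal row `target` of the weight matrix.
model = torch.nn.Linear(10, 3)
x = torch.randn(1, 10, requires_grad=True)
target = 2

model(x)[0, target].backward()
assert torch.allclose(x.grad[0], model.weight[target]), "saliency sanity check failed"
```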
Phase 2: Performance & Calibration (~2-3 minutes)
- Evaluates model accuracy on CIFAR-10 (in-distribution)
- Evaluates model accuracy on OOD datasets (CIFAR-100, SVHN)
- Computes Expected Calibration Error (ECE)
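For reference, ECE is the confidence-weighted gap between accuracy and confidence across confidence bins. A minimal sketch of the standard binned estimator follows; the bin count and variable names are our illustrative choices:

```python
import torch

def expected_calibration_error(confidences: torch.Tensor, correct: torch.Tensor, n_bins: int = 15) -> float:
    """Binned ECE: weighted average of |accuracy - confidence| per bin.

    `confidences` holds max softmax probabilities; `correct` is a boolean
    tensor marking whether each prediction matched the label.
    """
    edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = torch.tensor(0.0)
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = (correct[mask].float().mean() - confidences[mask].mean()).abs()
            ece += mask.float().mean() * gap
    return float(ece)
```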
Phase 3: Attribution Drift Analysis (~5-10 minutes) - Most Important
- Computes attributions on both ID and OOD data
- Measures how attribution patterns change (IoU, Pearson correlation)
- Key insight: Do explanations remain consistent under distribution shift?
Phase 4: Corruption Analysis (~5-15 minutes)
- Tests multiple corruption types (Gaussian noise, brightness, contrast)
- Multiple severity levels per corruption type
- Evaluates robustness under controlled distribution shifts
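A minimal sketch of how a corruption is applied at increasing severity (the severity-to-sigma schedule below is an assumption for illustration, not the project's exact configuration):

```python
import torch

def gaussian_noise(images: torch.Tensor, severity: int) -> torch.Tensor:
    """Add Gaussian noise to images in [0, 1] at severity 1-5."""
    sigmas = [0.04, 0.08, 0.12, 0.18, 0.26]  # assumed severity schedule
    noisy = images + torch.randn_like(images) * sigmas[severity - 1]
    return noisy.clamp(0.0, 1.0)

# Sweep severities to trace an accuracy-vs-corruption curve:
batch = torch.rand(8, 3, 32, 32)  # placeholder batch of images
corrupted = [gaussian_noise(batch, s) for s in range(1, 6)]
```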
| Experiment Type | Quick Test | Full Experiment |
|---|---|---|
| ResNet | 3-5 min | 15-25 min |
| ViT | 3-5 min | 15-25 min |
| Both + Comparison | 8-12 min | 35-55 min |
If Model Files Missing:
# Check what models are available
python experiments/setup_experiments.py
# Train models if needed
python experiments/training/main.py --model resnet --epochs 50
python experiments/training/main.py --model vit --epochs 50
If Running Out of Memory:
# Use smaller batch size
python experiments/ood_analysis/run_ood_experiments.py \
--model_path <path> --model_type <type> --batch_size 16
# Or use CPU
python experiments/ood_analysis/run_ood_experiments.py \
--model_path <path> --model_type <type> --device cpu
If Attribution Methods Fail:
- Some methods may fail on simple models (expected behavior)
- Attribution methods are more reliable on fully trained models
- Results will still be generated with whichever methods succeed
interpret_shifts/
├── README.md # Main project documentation
├── requirements.txt # Unified dependencies
├── requirements_ood.txt # OOD-specific dependencies
├── quick_demo.ipynb # One-click Colab demo
├── src/ # Core source code
│ ├── models/ # Model architectures
│ │ ├── __init__.py
│ │ ├── resnet.py # ResNet-18 implementation
│ │ └── vit.py # Vision Transformer implementation
│ └── utils/ # Utility functions
│ ├── __init__.py
│ ├── attribution_methods.py # Attribution computation methods
│ ├── datasets.py # Dataset loading utilities
│ ├── plot_utils.py # Plotting utilities
│ └── utils.py # General utilities
├── experiments/ # Experiment scripts
│ ├── training/ # Model training scripts
│ │ ├── main.py # Main training script
│ │ ├── main_gpu.py # GPU-optimized training
│ │ └── resnet_run.sh # ResNet training script
│ ├── ood_analysis/ # Out-of-distribution analysis
│ │ ├── analyze_cifar100_resnet.py # CIFAR-100 ResNet analysis
│ │ ├── analyze_svhn_resnet.py # SVHN ResNet analysis
│ │ ├── compare_ood_results.py # Results comparison
│ │ ├── run_ood_experiments.py # Main OOD experiment runner
│ │ ├── run_ood_experiments_cifar100.py
│ │ ├── run_ood_experiments_memory_optimized.py
│ │ └── vit_attribution_fix.py # ViT attribution fixes
│ ├── visualization/ # Visualization scripts
│ │ ├── generate_plots.py # Plot generation
│ │ ├── simple_visualize.py # Simple visualization
│ │ ├── visualize_cifar100_results.py
│ │ ├── visualize_ood_results.py # Main visualization script
│ │ └── visualize_saliency_maps_cifar100.py
│ ├── gpu_benchmark.py # GPU performance benchmarking
│ └── setup_experiments.py # Experiment setup utilities
├── scripts/ # Convenience scripts (symlinks to experiments/)
│ └── [mirrors experiments/ structure]
├── results/ # Experimental results
│ └── consolidated/ # All experimental results
│ ├── results_cifar100_resnet_gpu/
│ │ ├── analysis/ # Analysis reports and dashboards
│ │ ├── cifar100_ood_results_resnet_*.json
│ │ ├── cifar100_summary_resnet.csv
│ │ └── visualizations/ # Generated visualizations
│ ├── results_cifar100_vit/
│ │ ├── cifar100_ood_results_vit_*.json
│ │ ├── cifar100_summary_vit.csv
│ │ └── visualizations/ # ViT-specific visualizations
│ └── [other result directories]
├── docs/ # Documentation
│ ├── reports/ # Research reports
│ └── visualizations/ # Documentation visualizations
├── models/ # Trained model checkpoints
│ ├── resnet_cifar10_best.pth
│ └── vit-hf-scratch-small_cifar10_best.pth
└── examples/ # Usage examples
├── quick_start.py # Basic usage example
├── custom_analysis.py # Advanced analysis example
└── commands.txt # Useful commands
Source Code (src/)
- Models: ResNet-18 and Vision Transformer implementations
- Utils: Core utilities for training, evaluation, and analysis
- Attribution Methods: Saliency, Grad-CAM, Integrated Gradients, Attention Rollout
Experiments (experiments/)
- Training: Model training scripts with advanced optimization
- OOD Analysis: Comprehensive out-of-distribution evaluation
- Visualization: Analysis and plotting utilities
Results (results/consolidated/)
- Analysis reports and dashboards
- Generated visualizations and plots
- Model weights and checkpoints
- JSON results and CSV summaries
For Researchers:
- Start with this README for an overview
- Run experiments: `experiments/` for custom analysis
- View results: `results/consolidated/` for generated artifacts
- Review code: `src/` for implementation details
For Practitioners:
- Quick start: `examples/quick_start.py`
- Training: `experiments/training/main.py`
- Analysis: `experiments/ood_analysis/run_ood_experiments.py`
- Visualization: `experiments/visualization/visualize_ood_results.py`
For Developers:
- Core code: `src/` for implementation details
- Experiments: `experiments/` for methodology
- Examples: `examples/` for usage patterns
- Results: `results/` for validation
- In-Distribution: CIFAR-10 (10 classes, 32×32 images)
- Out-of-Distribution:
- CIFAR-100 (100 classes, similar domain)
- SVHN (different domain: street-view house numbers)
Dataset Selection Rationale: We deliberately chose CIFAR-10/100 and SVHN as our evaluation benchmarks for several critical reasons:
- Controlled Environment: These datasets provide a controlled laboratory setting for mechanistic analysis of distribution shift effects
- Computational Efficiency: Enables extensive experimentation with multiple models, attribution methods, and hyperparameter sweeps
- Established Baselines: Well-studied datasets with known failure modes, allowing for systematic comparison with existing literature
- Reproducibility: Standardized datasets ensure experimental reproducibility and facilitate comparison across research groups
- Mechanistic Insights: The relatively simple visual features in these datasets make it easier to understand and interpret model behavior under distribution shift
- Vision Transformer (ViT): Transformer-based architecture
- ResNet-18: Convolutional neural network with residual connections
- Saliency Maps: Gradient-based pixel importance
- Grad-CAM: Class activation mapping (CNN-specific)
- Integrated Gradients: Path-integrated attributions
- Attention Rollout: Transformer attention visualization
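A minimal sketch of computing two of these methods with Captum (the attribution library this project acknowledges); the tiny CNN and random batch below are placeholders for a trained checkpoint and real CIFAR images:

```python
import torch
import torch.nn as nn
from captum.attr import Saliency, IntegratedGradients

# Placeholder model and batch; swap in a trained checkpoint in practice.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
model.eval()
images = torch.randn(4, 3, 32, 32, requires_grad=True)
labels = torch.tensor([0, 1, 2, 3])

saliency_attr = Saliency(model).attribute(images, target=labels)
ig_attr = IntegratedGradients(model).attribute(images, target=labels, n_steps=32)
```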
- Performance: Accuracy, F1-score
- Calibration: Expected Calibration Error (ECE)
- Attribution Drift: IoU similarity, Pearson/Spearman correlation
- Semantic Analysis: Prediction distribution, confidence analysis
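The correlation-based drift metrics operate directly on flattened attribution maps. A sketch assuming SciPy is available; the placeholder maps are illustrative:

```python
import torch
from scipy.stats import pearsonr, spearmanr

# Flatten the attribution maps for the same image under ID and shifted
# conditions, then correlate; low correlation indicates attribution drift.
attr_id = torch.randn(3, 32, 32).flatten().numpy()   # placeholder
attr_ood = torch.randn(3, 32, 32).flatten().numpy()  # placeholder

pearson_r, _ = pearsonr(attr_id, attr_ood)
spearman_rho, _ = spearmanr(attr_id, attr_ood)
print(f"Pearson r = {pearson_r:.3f}, Spearman rho = {spearman_rho:.3f}")
```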
- Silent Failure Mode: Models fail catastrophically while appearing confident
- Explanation Unreliability: Attribution methods become meaningless on OOD data
- No Built-in Detection: Models lack mechanisms to identify distribution shift
- Systematic Biases: Predictable failure patterns that could be exploited
- ViT Strengths: Better attribution stability, superior calibration behavior
- ResNet Strengths: Computational efficiency, broad interpretability compatibility
- Both Limitations: Catastrophic OOD performance, unreliable confidence estimates
- Modular Design: Easy to extend with new models and attribution methods
- Comprehensive Evaluation: Multi-faceted analysis combining performance, calibration, and interpretability
- Reproducible Results: Fixed random seeds and complete parameter specifications
- Memory Optimization: Efficient data loading and GPU utilization
- Cross-dataset validation: Multiple OOD scenarios
- Corruption robustness: Performance under various corruption levels
- Semantic coherence: Analysis of prediction patterns
- Real-time monitoring: Training progress and convergence analysis
- Comprehensive OOD Analysis: Detailed performance and calibration breakdown
- Attribution Drift Study: Method-specific robustness analysis
- Architecture Comparison: ResNet vs ViT comprehensive evaluation
- Safety Assessment: Risk analysis for production deployment
- Performance Dashboards: Multi-metric comparison charts
- Attribution Drift Plots: IoU and correlation analysis
- Semantic Analysis: Prediction distribution and confidence patterns
- Calibration Assessment: Reliability diagrams and ECE analysis
- AI Safety Research: Understanding model limitations under distribution shift
- Interpretability Studies: Evaluating explanation method robustness
- Model Evaluation: Comprehensive assessment beyond accuracy metrics
- OOD Detection: Developing methods for distribution shift identification
- Medical AI: Ensuring reliable diagnosis across diverse patient populations
- Autonomous Systems: Safe deployment in novel environments
- Financial Systems: Robust risk assessment under changing market conditions
- Quality Assurance: Model reliability testing for production deployment
- Complete ResNet analysis on additional OOD datasets
- Implement proposed OOD detection methods
- Test calibration improvement techniques
- Extend to vision-language models
- Hybrid architectures combining CNN and Transformer strengths
- Attribution-aware training for consistent explanations
- Real-world deployment testing
- Regulatory framework development for OOD safety
- Attention Is All You Need - Transformer architecture
- ResNet - Residual networks
- Grad-CAM - Class activation mapping
- Integrated Gradients - Attribution method
- Distribution shift detection methods
- Model calibration techniques
- Attribution method robustness
- AI safety and reliability
We welcome contributions! Please see our Contributing Guidelines for details.
- New attribution methods
- Additional OOD datasets
- Improved visualization tools
- Documentation improvements
- Performance optimizations
This project is licensed under the MIT License - see the LICENSE file for details.
Laura Gomez - Initial work and comprehensive analysis
- GitHub: @lauragomez
- Email: laura.gomez@stanford.edu
- PyTorch team for the excellent deep learning framework
- Hugging Face for transformer model implementations
- Captum team for attribution method implementations
- CIFAR and SVHN dataset creators for providing evaluation benchmarks
For questions, suggestions, or collaboration opportunities:
- Email: laura.gomez@stanford.edu
- GitHub Issues: Create an issue
- Discussions: Join the conversation
Important Notice: This research demonstrates critical safety limitations in current deep learning models. Models should never be deployed in production without proper OOD detection and uncertainty quantification mechanisms.
Research Impact: This work provides both a sobering assessment of current AI limitations and a roadmap for building more reliable, interpretable, and safe AI systems.