This repository contains the implementation of a Triplet Variational Autoencoder (Tri-VAE) designed to detect anomalies in brain MRI scans. The method is inspired by the CVPR 2024 paper: "Triplet Variational Autoencoder for Unsupervised Anomaly Detection in Brain MRI".
Tri-VAE is trained to model the distribution of healthy brain MRIs and detect abnormal regions, such as tumors, by analyzing reconstruction errors. It combines variational inference with metric learning through a triplet loss to enhance the model's discriminative capacity for anomaly localization.
- Overview
- Model Architecture
- Dataset
- Implementation
- Evaluation
- Visual Results
- Limitations and Future Work
## Overview

The Tri-VAE approach learns a compact latent space representation of healthy brain MRIs. It utilizes:
- Anchor: A clean, healthy MRI slice.
- Positive: Another healthy slice (noise-free).
- Negative: A healthy slice corrupted with artificial noise (coarse noise or simplex noise).
By pulling the anchor and positive embeddings together while pushing the negative embedding away, the model learns a latent space in which structural anomalies stand apart from healthy anatomy.
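As a rough illustration of how such triplets could be assembled, the sketch below builds the negative by adding low-resolution random noise upsampled to image size — a simple stand-in for the simplex noise used in the paper. The function name, noise scale, and coarse grid size are illustrative assumptions, not taken from the repository:

```python
import torch
import torch.nn.functional as F

def make_triplet(healthy_batch, noise_scale=0.2, coarse_size=8):
    """Build (anchor, positive, negative) from a batch of healthy slices.

    Anchor and positive are two clean healthy slices; the negative is a
    healthy slice corrupted with low-frequency noise (a simple stand-in
    for simplex noise). All values here are illustrative assumptions.
    """
    anchor = healthy_batch                      # clean slice
    positive = torch.roll(healthy_batch, 1, 0)  # another clean slice from the batch
    b, c, h, w = healthy_batch.shape
    # sample coarse noise on a small grid and upsample -> low-frequency corruption
    coarse = torch.randn(b, c, coarse_size, coarse_size)
    noise = F.interpolate(coarse, size=(h, w), mode="bilinear", align_corners=False)
    negative = (healthy_batch + noise_scale * noise).clamp(0.0, 1.0)
    return anchor, positive, negative
```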
Core Components:
- Coarse and full-scale reconstruction branches
- Triplet loss
- L1 reconstruction loss
- KL divergence loss
- Structural Similarity Index (SSIM) loss
The trained model is evaluated on the BraTS dataset by comparing reconstruction errors. Anomalies (tumor regions) are expected to yield higher reconstruction errors.
## Model Architecture

The Tri-VAE architecture includes the following components:
- Encoder
- Coarse Decoder
- Full Decoder
- Triplet loss branch
- SSIM evaluation module
## Dataset

- IXI Dataset: Used for training on healthy MRI slices.
- BraTS Dataset: Used for evaluating tumor detection.
Data preprocessing includes:
- Skull stripping
- Normalization
- Resizing MRI slices
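A minimal sketch of the preprocessing steps above, assuming a brain mask from an external skull-stripping tool, min-max normalization, and an illustrative 128×128 target resolution (the repository's exact choices may differ):

```python
import torch
import torch.nn.functional as F

def preprocess_slice(slice_2d, brain_mask, out_size=(128, 128)):
    """Skull-strip, normalize, and resize one 2D MRI slice.

    `brain_mask` is assumed to come from an external skull-stripping tool;
    the target resolution is an illustrative assumption.
    """
    stripped = slice_2d * brain_mask                 # zero out non-brain voxels
    lo, hi = stripped.min(), stripped.max()
    normed = (stripped - lo) / (hi - lo + 1e-8)      # min-max scale to [0, 1]
    resized = F.interpolate(normed[None, None], size=out_size,
                            mode="bilinear", align_corners=False)
    return resized[0, 0]
```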
## Implementation

The model is implemented using PyTorch and structured as follows:
Each input is grouped as:
- Anchor (A): A healthy slice from IXI
- Positive (P): Another healthy slice without noise
- Negative (N): A healthy slice with added artificial noise
These are passed through a shared encoder to produce latent embeddings, and reconstruction is carried out at two scales:
- Coarse Decoder: Provides low-resolution reconstruction
- Full Decoder: Produces detailed output from coarse output + latent features
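The shared encoder and two-scale decoding can be sketched as a schematic module. Layer widths, the latent dimension, and the exact way coarse output and latent features are fused are illustrative assumptions, not the paper's configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TriVAE(nn.Module):
    """Schematic Tri-VAE: shared encoder, coarse decoder, full decoder.

    Layer sizes are illustrative placeholders, not the paper's exact design.
    """
    def __init__(self, latent_dim=64):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1), nn.ReLU(),   # H -> H/2
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),  # H/2 -> H/4
        )
        self.to_mu = nn.Conv2d(64, latent_dim, 1)
        self.to_logvar = nn.Conv2d(64, latent_dim, 1)
        # coarse decoder: latent map -> low-resolution reconstruction
        self.coarse_dec = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),
        )
        # full decoder: upsampled coarse output + latent features -> full image
        self.full_dec = nn.Sequential(
            nn.Conv2d(1 + latent_dim, 32, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(32, 1, 3, 1, 1),
        )

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.to_mu(h), self.to_logvar(h)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterization
        coarse = self.coarse_dec(z)                            # half-resolution output
        coarse_up = F.interpolate(coarse, size=x.shape[-2:],
                                  mode="bilinear", align_corners=False)
        z_up = F.interpolate(z, size=x.shape[-2:],
                             mode="bilinear", align_corners=False)
        full = self.full_dec(torch.cat([coarse_up, z_up], dim=1))
        return full, coarse, mu, logvar
```

All three triplet members pass through the same `enc`, so their latent embeddings live in one space and can be compared by the triplet loss.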
The training is guided by a combination of:
- Coarse Reconstruction L1 Loss – Pixel-wise accuracy of the coarse reconstruction, applied to all images
- Full Reconstruction L1 Loss – Pixel-wise accuracy of the full reconstruction, applied to negative samples
- KL Divergence – Regularization for the VAE
- Triplet Loss – Metric learning signal for embedding separation
- SSIM Loss – Structural similarity preservation
These losses are balanced to enforce latent structure learning and faithful reconstructions.
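One way the five terms could be combined is sketched below. The loss weights, the margin, and the simplified whole-image SSIM (the paper's SSIM loss would use local windows) are all illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def ssim_global(x, y, c1=0.01 ** 2, c2=0.03 ** 2):
    """Simplified SSIM computed over each whole image (a compact stand-in
    for the windowed SSIM typically used in SSIM losses)."""
    mx, my = x.mean(dim=(1, 2, 3)), y.mean(dim=(1, 2, 3))
    vx = x.var(dim=(1, 2, 3), unbiased=False)
    vy = y.var(dim=(1, 2, 3), unbiased=False)
    cov = ((x - mx[:, None, None, None]) *
           (y - my[:, None, None, None])).mean(dim=(1, 2, 3))
    return ((2 * mx * my + c1) * (2 * cov + c2)) / \
           ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2))

def tri_vae_loss(coarse_all, targets_all, full_neg, target_neg,
                 mu, logvar, z_a, z_p, z_n,
                 w_kl=1e-3, w_tri=1.0, w_ssim=0.5, margin=1.0):
    """Weighted sum of the five training terms; weights are placeholders."""
    l1_coarse = F.l1_loss(coarse_all, targets_all)        # coarse L1, all images
    l1_full = F.l1_loss(full_neg, target_neg)             # full L1, negatives only
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    tri = F.triplet_margin_loss(z_a, z_p, z_n, margin=margin)
    ssim = 1.0 - ssim_global(full_neg, target_neg).mean() # SSIM loss
    return l1_coarse + l1_full + w_kl * kl + w_tri * tri + w_ssim * ssim
```

Here `target_neg` is the clean (pre-noise) version of the negative slice, so the full decoder is pushed to reconstruct healthy anatomy from corrupted input.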
## Evaluation

Evaluation focuses on anomaly detection performance:
- Reconstruction Error Maps: The absolute difference between each input slice and its reconstruction yields a per-pixel error map.
- Dice Score: Computed against ground truth tumor masks (BraTS) to evaluate detection accuracy.
- Visualizations: Tumor regions typically appear as high-error areas in the residuals.
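The error-map and Dice computation can be sketched as follows; the fixed threshold is an illustrative constant, whereas in practice it would be tuned on a validation set or swept to report the best achievable Dice:

```python
import torch

def anomaly_map_and_dice(x, x_rec, gt_mask, threshold=0.3):
    """Residual-based anomaly map and Dice score against a tumor mask.

    `threshold` is an illustrative constant, not a value from the repo.
    """
    error_map = (x - x_rec).abs()                  # per-pixel reconstruction error
    pred_mask = (error_map > threshold).float()    # binarize into an anomaly mask
    inter = (pred_mask * gt_mask).sum()
    dice = (2 * inter / (pred_mask.sum() + gt_mask.sum() + 1e-8)).item()
    return error_map, pred_mask, dice
```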
## Visual Results

Below are examples of anomaly localization via reconstruction error:
Normal VAE:
Triplet VAE:
## Limitations and Future Work

- The model is currently limited to 2D slice-based analysis.
- No domain adaptation is performed between IXI and BraTS.
- Future extensions could include attention mechanisms, 3D volumetric models, and unsupervised domain adaptation.