# Multi-Label Chest X-Ray Classification

This project provides a robust, end-to-end pipeline for multi-label classification of chest X-ray images using deep learning and transfer learning. The model predicts the presence of multiple thoracic diseases in a single X-ray image, leveraging modern transformer-based architectures.
## Table of Contents

- Features
- Technologies Used
- Project Structure
- Dataset
- Installation
- Usage
- Model Training and Evaluation
- Testing the Model
- Results
- Future Improvements
- License
- Acknowledgments
## Features

- Handles multi-label classification for medical images (14 thoracic disease classes).
- Utilizes Swin Transformer (or configurable backbone) with transfer learning via timm.
- Implements advanced evaluation metrics: Precision, Recall, F1-score, and AUC (macro).
- Includes threshold optimization for best validation precision.
- Robust error handling and logging for data loading and preprocessing.
- Modular, extensible codebase for research and experimentation.
## Technologies Used

- Python 3.8+
- PyTorch for deep learning model implementation.
- timm for state-of-the-art vision backbones (Swin Transformer, etc.).
- torchmetrics for evaluation metrics.
- Pandas and NumPy for data manipulation.
- scikit-learn (optional, for further analysis).
- Pillow for image processing.
## Project Structure

```
.
|-- script.py           # Main implementation file (PyTorch pipeline)
|-- testscript.py       # Script for testing the model
|-- data.csv            # Input dataset (CSV format)
|-- images/             # Directory containing chest X-ray images
|-- output/             # Directory for model checkpoints and logs
|-- README.md           # Documentation
|-- requirements.txt    # Python dependencies
```
## Dataset

The project expects a dataset of chest X-ray images annotated with multiple disease labels. Each image is associated with one or more binary labels indicating the presence of specific conditions.
- The CSV file (`data.csv`) should contain:
  - A column `DicomPath_y` specifying the file path to each X-ray image.
  - A column for each disease label (e.g., `Atelectasis`, `Cardiomegaly`), with values 0 (absent), 1 (present), or -1/NaN (treated as 0).
  - A `split` column indicating the data split (`train`, `validate`, `test`).

Example:
| DicomPath_y | Atelectasis | Cardiomegaly | Pneumonia | ... | split |
|---|---|---|---|---|---|
| images/img1.png | 1 | 0 | 0 | ... | train |
| images/img2.png | 0 | 1 | 1 | ... | validate |
| images/img3.png | 0 | 0 | 0 | ... | test |
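Loading and cleaning a CSV in this format can be sketched as follows (a minimal pandas example using the column names from the sample table; the actual script's preprocessing may differ):

```python
import io
import pandas as pd

# Inline sample standing in for data.csv (row 1 has a NaN, row 2 a -1).
CSV = """DicomPath_y,Atelectasis,Cardiomegaly,Pneumonia,split
images/img1.png,1,0,,train
images/img2.png,0,-1,1,validate
"""

LABELS = ["Atelectasis", "Cardiomegaly", "Pneumonia"]

df = pd.read_csv(io.StringIO(CSV))
# Dataset convention: -1 and NaN both mean "absent", i.e. 0.
df[LABELS] = df[LABELS].fillna(0).replace(-1, 0).astype(int)

# Select one split for a DataLoader.
train_df = df[df["split"] == "train"].reset_index(drop=True)
```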
## Installation

1. Clone the repository:

   ```bash
   git clone https://github.com/your-repo/multi-label-xray-classifier.git
   cd multi-label-xray-classifier
   ```

2. Install the required dependencies:

   ```bash
   pip install -r requirements.txt
   ```

3. Prepare the dataset: place your `data.csv` and images in the appropriate directories as described above.
## Usage

1. Configure settings: edit the `Config` class in `script.py` to adjust the model backbone, batch size, learning rate, etc.

2. Run the training and evaluation pipeline:

   ```bash
   python script.py
   ```

   The script will:
   - Load and preprocess the data.
   - Train the model with the specified configuration.
   - Evaluate on the validation and test sets.
   - Optimize the classification threshold for best precision.
   - Save the best model checkpoint and log results.

3. Check the `output/` directory for saved models and logs.
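The `Config` class lives in `script.py` and its exact fields are not shown here, so the following is a hypothetical sketch of the kind of settings it centralizes; the real field names and defaults may differ:

```python
from dataclasses import dataclass

@dataclass
class Config:
    """Hypothetical training configuration (illustrative only)."""
    backbone: str = "swin_base_patch4_window7_224"  # any timm model name
    num_classes: int = 14       # 14 thoracic disease labels
    batch_size: int = 16
    learning_rate: float = 1e-4
    epochs: int = 10
    image_size: int = 224
    threshold: float = 0.5      # starting point before threshold search

# Override individual fields without touching the defaults.
cfg = Config(batch_size=32, learning_rate=3e-5)
```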
## Model Training and Evaluation

- Model: Swin Transformer (default) or a configurable backbone from `timm`.
- Loss: Binary Cross-Entropy with Logits (multi-label).
- Metrics: Precision, Recall, F1-score, and AUC (macro, thresholded).
- Threshold optimization: automatically searches for the best threshold on the validation set.
- Logging: detailed logs for each epoch, including metrics and checkpointing.
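The threshold search can be illustrated with a minimal NumPy sketch: scan a fixed grid of candidate thresholds on validation probabilities and keep the one that maximizes precision. This is illustrative only; the actual script's grid, averaging, and tie-breaking may differ.

```python
import numpy as np

def best_precision_threshold(probs, labels, grid=None):
    """Return (threshold, precision) maximizing micro-averaged
    precision of (probs >= threshold) against binary labels."""
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)
    best_t, best_p = 0.5, -1.0
    for t in grid:
        preds = (probs >= t).astype(int)
        tp = int(np.sum((preds == 1) & (labels == 1)))
        fp = int(np.sum((preds == 1) & (labels == 0)))
        if tp + fp == 0:
            continue  # no positive predictions at this threshold
        precision = tp / (tp + fp)
        if precision > best_p:
            best_t, best_p = t, precision
    return best_t, best_p

# Toy validation set: 3 images x 2 classes.
val_probs = np.array([[0.9, 0.2], [0.6, 0.8], [0.3, 0.1]])
val_labels = np.array([[1, 0], [1, 1], [0, 0]])
t, p = best_precision_threshold(val_probs, val_labels)
```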
## Testing the Model

To evaluate a trained model checkpoint on the test set, use `testscript.py`.
Command (Bash/Linux/macOS):

```bash
python d:/personal-projects/HealthCare/DataSets/chest-x-ray-dataset-with-lung-segmentation-1.0.0/testscript.py --checkpoint_path path/to/your/best_model.pth
```

Command (PowerShell/Windows):

```powershell
python d:\personal-projects\HealthCare\DataSets\chest-x-ray-dataset-with-lung-segmentation-1.0.0\testscript.py --checkpoint_path path\to\your\best_model.pth
# Or, using .\ if in the script's directory:
# python .\testscript.py --checkpoint_path path\to\your\best_model.pth
```

Explanation:

- `--checkpoint_path` (required): the full path to the saved model checkpoint file (e.g., `output/best_model.pth` or `output\best_model.pth`, generated during training). Use the appropriate path separator for your OS (`/` for Linux/macOS, `\` for Windows).
Optional arguments:

You can override settings stored in the checkpoint's configuration by providing additional arguments:

- `--data_dir`: path to the directory containing the data and CSV file (e.g., `.`).
- `--csv_file`: name of the CSV file (e.g., `data.csv`).
- `--batch_size`: batch size for evaluation (larger sizes may speed up testing).
- `--device`: device to run evaluation on (e.g., `cuda`, `cpu`).
- `--threshold`: classification threshold to use (to test a specific value instead of the one from the config).
- `--num_workers`: number of data-loading workers.
- `--output_dir`: directory to save the test log file (defaults to `test_output`).
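A CLI surface like the one above could be declared with `argparse` as follows (a hypothetical sketch; the real `testscript.py` may define these options differently):

```python
import argparse

parser = argparse.ArgumentParser(description="Evaluate a saved checkpoint")
parser.add_argument("--checkpoint_path", required=True)
# Overrides default to None so the script can fall back to the
# configuration stored inside the checkpoint.
parser.add_argument("--data_dir", default=None)
parser.add_argument("--csv_file", default=None)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--device", choices=["cuda", "cpu"], default=None)
parser.add_argument("--threshold", type=float, default=None)
parser.add_argument("--num_workers", type=int, default=None)
parser.add_argument("--output_dir", default="test_output")

# Simulated invocation (normally argparse reads sys.argv).
args = parser.parse_args(
    ["--checkpoint_path", "output/best_model.pth", "--batch_size", "32"]
)
```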
Example with overrides (Bash/Linux/macOS):
python d:\personal-projects\HealthCare\DataSets\chest-x-ray-dataset-with-lung-segmentation-1.0.0\testscript.py \
--checkpoint_path output/best_model.pth \
--data_dir . \
--csv_file data.csv \
--batch_size 32 \
--device cuda \
--output_dir test_run_1Example with overrides (PowerShell/Windows):
# Single line:
python d:\personal-projects\HealthCare\DataSets\chest-x-ray-dataset-with-lung-segmentation-1.0.0\testscript.py --checkpoint_path output\best_model.pth --data_dir . --csv_file data.csv --batch_size 32 --device cuda --output_dir test_run_1
# Multi-line using backticks (`):
python d:\personal-projects\HealthCare\DataSets\chest-x-ray-dataset-with-lung-segmentation-1.0.0\testscript.py `
--checkpoint_path output\best_model.pth `
--data_dir . `
--csv_file data.csv `
--batch_size 32 `
--device cuda `
--output_dir test_run_1The script will load the specified checkpoint, prepare the test dataset based on the configuration (or overrides), run the evaluation, and print the performance metrics (Loss, Precision, Recall, F1-Score, AUC, Per-Class F1, Confusion Matrix) to the console and the log file within the specified output directory.
## Results

- Best validation precision and the corresponding threshold are reported.
- Final test set metrics (loss, precision, recall, F1, AUC) are logged.

Example (replace with your actual results):

```
Best validation precision: 0.82 at threshold 0.45
Test set results:
  Loss: 0.23
  Precision: 0.81
  Recall: 0.78
  F1-score: 0.79
  AUC: 0.88
```
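For reference, the macro-averaged F1 reported above can be computed from binarized predictions as in this minimal NumPy sketch (illustrative only; the project uses `torchmetrics` for its actual metrics):

```python
import numpy as np

def macro_f1(preds, labels):
    """Macro-averaged F1: per-class F1 over the label columns,
    then the unweighted mean across classes."""
    f1s = []
    for c in range(labels.shape[1]):
        tp = np.sum((preds[:, c] == 1) & (labels[:, c] == 1))
        fp = np.sum((preds[:, c] == 1) & (labels[:, c] == 0))
        fn = np.sum((preds[:, c] == 0) & (labels[:, c] == 1))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return float(np.mean(f1s))

# Toy example: 3 images x 2 classes.
preds = np.array([[1, 0], [1, 1], [0, 0]])
labels = np.array([[1, 0], [0, 1], [0, 0]])
score = macro_f1(preds, labels)
```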
## Future Improvements

- Add support for additional or custom backbones (e.g., ConvNeXt, EfficientNet).
- Implement data augmentation and advanced regularization.
- Add per-class threshold optimization.
- Integrate explainability (Grad-CAM, saliency maps).
- Support for DICOM image loading and preprocessing.
- Distributed/multi-GPU training.
## License

This project is licensed under the MIT License. See the LICENSE file for details.
## Acknowledgments

- timm for backbone models.
- Dataset contributors for their invaluable work in medical imaging.
- PyTorch and torchmetrics developers for their open-source tools.