This is the repository of BraTS-ReportX, a paired resource of 257 clinical reports aligned to BraTS subjects and structured into a rich set of qualitative and quantitative attributes. This repository contains code for: (1) automatic generation of quantitative report attributes from BraTS data, including anatomical localization and geometric measurements; (2) report encoding with biomedical language models; (3) evaluation of the semantic coverage and overall quality of the dataset, supporting analyses of how well BraTS-ReportX captures clinically relevant report information compared with existing resources; and (4) training and testing of the proposed vision-text alignment framework for 3D tumor segmentation. The codebase is designed to support reproducibility and further research on integrating structured clinical semantics into medical image segmentation.
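As an illustration of the geometric measurements in point (1), attributes such as lesion volume, centroid, and extent can be read directly off a binary segmentation mask. The NumPy sketch below is purely illustrative: the function name, voxel-spacing handling, and returned fields are assumptions, not the repository's implementation.

```python
import numpy as np

def lesion_geometry(mask, spacing=(1.0, 1.0, 1.0)):
    """Toy geometric attributes from a binary 3D mask.

    mask: boolean array of shape (D, H, W).
    spacing: voxel size in mm along each axis (assumed isotropic here).
    """
    voxels = np.argwhere(mask)                       # coordinates of lesion voxels
    spacing = np.asarray(spacing, dtype=float)
    voxel_volume = float(np.prod(spacing))           # mm^3 per voxel
    volume_ml = mask.sum() * voxel_volume / 1000.0   # 1 mL = 1000 mm^3
    centroid_mm = voxels.mean(axis=0) * spacing      # center of mass in mm
    extent_mm = (voxels.max(axis=0) - voxels.min(axis=0) + 1) * spacing
    return {"volume_ml": volume_ml,
            "centroid_mm": centroid_mm,
            "extent_mm": extent_mm}

# A 4x4x4 cube of lesion voxels inside a 10x10x10 volume.
mask = np.zeros((10, 10, 10), dtype=bool)
mask[2:6, 2:6, 2:6] = True
geometry = lesion_geometry(mask, spacing=(1.0, 1.0, 1.0))
```

Real report attributes would additionally need image orientation and affine information to express the centroid in anatomical coordinates.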
Overview of the annotation protocol. Clinician reports and automatically generated reports are produced independently and then concatenated.
Our segmentation pipeline overview. Encoder features from a 3D U-Net are projected into flat visual embeddings, while clinical reports are mapped to text embeddings. A contrastive vision-text module aligns both modalities during training, while inference relies only on the image backbone.
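The contrastive vision-text module described above can be understood through a standard symmetric InfoNCE objective over paired image/report embeddings. This NumPy sketch is a generic formulation, not the repository's exact loss; the temperature value and normalization choices are assumptions.

```python
import numpy as np

def symmetric_info_nce(img_emb, txt_emb, temperature=0.07):
    """Generic symmetric contrastive loss: matched image/report pairs
    (same row index) are pulled together, mismatched pairs pushed apart."""
    img = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    txt = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (N, N) cosine-similarity matrix

    def cross_entropy_diag(l):
        # softmax cross-entropy with the matching pair on the diagonal
        l = l - l.max(axis=1, keepdims=True)
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        idx = np.arange(l.shape[0])
        return -log_probs[idx, idx].mean()

    # average the image-to-text and text-to-image directions
    return 0.5 * (cross_entropy_diag(logits) + cross_entropy_diag(logits.T))

rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 16))
aligned = symmetric_info_nce(emb, emb)                        # matched pairs
shuffled = symmetric_info_nce(emb, np.roll(emb, 1, axis=0))   # mismatched pairs
```

As expected for a contrastive objective, the loss is lower when each image embedding sits next to its own report embedding than when the pairing is shuffled.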
Brain-Segmentation/
├── base/ # Abstract base classes
│ ├── base_dataset2d_sliced.py
│ ├── base_dataset.py
│ ├── base_model.py
│ └── base_trainer.py
├── config/ # Configuration files for training and transforms
│ ├── config_atlas.json
│ ├── atlas_transforms.json
│ └── ...
├── datasets/ # Dataset loading and preprocessing (inherited from base_datasets)
│ ├── DatasetFactory.py
│ ├── ATLAS.py
│ └── BraTS2D.py
├── losses/ # Loss function implementations
│ └── LossFactory.py
├── metrics/ # Metrics computation and tracking
│ ├── MetricsFactory.py
│ └── MetricsManager.py
├── models/ # Model architectures (inherited from base_model)
│ ├── ModelFactory.py
│ ├── UNet2D.py
│ └── UNet3D.py
├── optimizers/ # Optimizer configurations
│ └── OptimizerFactory.py
├── trainer/ # Training logic (inherited from base_trainer)
│ ├── trainer_2Dsliced.py
│ └── trainer_3D.py
├── transforms/ # Data augmentation and preprocessing
│ └── TransformsFactory.py
├── utils/ # Utility functions
│ ├── util.py
│ └── pad_unpad.py
├── scripts/ # Utility scripts (e.g., text embeddings)
│ ├── extract_textemb_biobert.py
│ └── preprocess_qatacov.py
├── report_generation/ # Code for report generation and evaluation
│ ├── agreement
│ │ ├── radfact.py
│ │ ├── radfact-70b
│ │ ├── run-auto-agreement.py
│ │ └── run-radfact-agreement.py
│ ├── autogen
│ │ ├── fill_atlas.py
│ │ ├── generate_atlas.py
│ │ ├── generate_eloquent.py
│ │ ├── generate_reports
│ │ └── run_cc_segmentation.py
│ ├── utils
│ │ ├── data.py
│ │ ├── geometries.py
│ │ └── jsonify.py
│ └── README.md
├── config.py # Config file handler
├── main.py # Training entry point
└── requirements.txt # Python dependencies
This project was run and tested with Python 3.11 and CUDA.
- Clone the repository:

```bash
git clone https://github.com/kev98/Medical-Image-Segmentation.git
cd Medical-Image-Segmentation
```

- Create and activate a virtual environment:

```bash
python3.11 -m venv .venv
source .venv/bin/activate
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

The following are some basic examples; you can add any other CLI parameters useful for your main.py (which must be the entrypoint for training).
Command line arguments implemented in the provided main.py file:
- `--config`: Path to configuration JSON file (required)
- `--epochs`: Number of training epochs (required)
- `--save_path`: Directory to save model checkpoints (required)
- `--trainer`: Trainer class name (required)
- `--validation`: Enable validation during training (flag)
- `--val_every`: Run validation every N epochs (default: 1)
- `--resume`: Resume training from the last checkpoint (flag)
- `--debug`: Enable debug mode with verbose output (flag)
- `--eval_metric_type`: Metric type for model selection: `mean` (per-class mean) or `aggregated_mean` (aggregated regions mean) (default: `mean`)
- `--wandb`: Enable Weights & Biases logging (flag). The run name will be `config.name`. Set the project and entity with environment variables: `export WANDB_ENTITY="your_entity"` and `export WANDB_PROJECT="your_project"`
- `--mixed_precision`: Enable mixed precision training: `fp16` or `bf16` (default: None, i.e. training runs in FP32)
- `--seed`: Random seed for reproducibility (default: 42)
Example launch of main.py, training a 3D segmentation model with validation, resuming from the last checkpoint, and logging to Weights & Biases:
```bash
source /path_to_your_venv/bin/activate
export WANDB_ENTITY="name_of_your_entity"
export WANDB_PROJECT="name_of_your_project"

python main.py \
  --config config/config_atlas.json \
  --epochs 100 \
  --save_path /folder_containing_model_last.pth \
  --trainer Trainer_3D \
  --validation \
  --val_every 2 \
  --resume \
  --wandb
```

To set up a complete training pipeline, follow these steps:
1. Create a Dataset Class: inherit from `BaseDataset` or `BaseDataset2DSliced` and implement the required abstract methods.
2. Implement a Model: create your custom model by inheriting from `BaseModel` and implementing the `forward()` method.
3. Implement a Trainer: create a custom trainer by inheriting from `BaseTrainer` and implementing the `_train_epoch()` and `eval_epoch()` methods.
4. Create an Entrypoint: write a `main.py` file that loads your configuration and instantiates your trainer. Use the provided `main.py` as a template or reference.
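The four steps above follow a simple subclassing pattern. The stub base classes below only mirror the abstract interface named in this README (`forward()`, `_train_epoch()`, `eval_epoch()`) and stand in for the real classes in `base/`; treat this as an illustrative sketch rather than the actual API.

```python
from abc import ABC, abstractmethod

class BaseModel(ABC):                 # stand-in for base/base_model.py
    @abstractmethod
    def forward(self, x): ...

class BaseTrainer(ABC):               # stand-in for base/base_trainer.py
    def __init__(self, model, epochs):
        self.model, self.epochs = model, epochs

    @abstractmethod
    def _train_epoch(self, epoch): ...

    @abstractmethod
    def eval_epoch(self, epoch): ...

    def train(self):
        # the real trainer also handles validation, checkpoints, logging, ...
        for epoch in range(self.epochs):
            self._train_epoch(epoch)

class TinyModel(BaseModel):
    def forward(self, x):
        return [v * 2 for v in x]     # toy "prediction"

class TinyTrainer(BaseTrainer):
    def _train_epoch(self, epoch):
        self.last = self.model.forward([1, 2, 3])

    def eval_epoch(self, epoch):
        return sum(self.last)         # toy "metric"

trainer = TinyTrainer(TinyModel(), epochs=1)
trainer.train()
```

A custom dataset class would follow the same pattern against `BaseDataset`/`BaseDataset2DSliced`, and `main.py` would read the JSON config and wire model, dataset, and trainer together.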
For detailed documentation on each component, refer to the README files in their respective directories:
- Base Classes - Abstract base classes for datasets, models, and trainers
- Configuration - JSON configuration files for training and transforms
- Datasets - Dataset loading and preprocessing
- Losses - Loss function implementations
- Metrics - Metrics computation and tracking
- Models - Model architectures
- Optimizers - Optimizer configurations
- Trainers - Training logic
- Transforms - Data augmentation and preprocessing
- Utils - Utility functions
- For patch-based training with 3D volumes, the framework uses TorchIO's Queue and GridSampler.
- Metrics are automatically computed per-class and averaged.
- Checkpoints are saved as
model_last.pthandmodel_best.pthin the folder specified by the parameter --save_path. - The framework is compatible with PyTorch 2.3+ and uses TorchIO's SubjectsLoader for proper data handling.

